Coverage for pySDC/core/convergence_controller.py: 84%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import logging 

2from pySDC.helpers.pysdc_helper import FrozenClass 

3 

4 

class Pars(FrozenClass):
    """
    Short helper class that exposes the parameters of a convergence controller as attributes.

    The attributes `control_order` (default 0) and `useMPI` (default None) always exist. Afterwards every key of the
    `params` dict becomes an attribute with the corresponding value, overriding these defaults if the keys coincide.
    Finally `_freeze()` (inherited from FrozenClass) is called; presumably this forbids adding further attributes.
    """

    def __init__(self, params):
        # integer that determines the order in which the convergence controllers are called
        self.control_order = 0
        # whether MPI is used depends on the controller
        self.useMPI = None

        for attr_name, attr_value in params.items():
            setattr(self, attr_name, attr_value)

        self._freeze()

15 

16 

class Status(FrozenClass):
    """
    Short helper class to store status variables.

    Initialize status variables with None, since at the time of instantiation of the convergence controllers, not all
    relevant information about the controller is known.
    """

    def __init__(self, status_variabes):
        """
        Create one attribute per name, each initialized to None, then call `_freeze()` (inherited from FrozenClass).

        Args:
            status_variabes (iterable of str): Names of the status variables. The misspelled parameter name is kept
                for backward compatibility with callers passing it by keyword.
        """
        # A plain loop: we only want the side effect, not a throwaway list of Nones from a comprehension.
        for key in status_variabes:
            setattr(self, key, None)

        self._freeze()

28 

29 

class ConvergenceController(object):
    """
    Base abstract class for convergence controller, which is plugged into the controller to determine the iteration
    count and time step size.

    Child classes override the hook methods below; in this base class all hooks do nothing.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Raises:
            AssertionError: If `check_parameters` reports that the parameters are not compatible
        """
        self.controller = controller
        # Create the logger before calling any hooks, so that `setup`, `check_parameters` and `dependencies` of child
        # classes can already use `self.log` / `self.debug`.
        self.logger = logging.getLogger(f"{type(self).__name__}")
        self.params = Pars(self.setup(controller, params, description))

        params_ok, msg = self.check_parameters(controller, params, description)
        # Raise explicitly instead of using `assert`, so the check survives `python -O`. The exception type and
        # message are the same as before for callers that catch it.
        if not params_ok:
            raise AssertionError(f'{type(self).__name__} -- {msg}')

        self.dependencies(controller, description)

        if self.params.useMPI:
            self.prepare_MPI_datatypes()

    def prepare_MPI_logical_operations(self):
        """
        Prepare MPI logical operations so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_LAND = MPI.LAND
        self.MPI_LOR = MPI.LOR

    def prepare_MPI_datatypes(self):
        """
        Prepare MPI datatypes so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_INT = MPI.INT
        self.MPI_DOUBLE = MPI.DOUBLE
        self.MPI_BOOL = MPI.BOOL

    def log(self, msg, S, level=15, **kwargs):
        """
        Shortcut that has a default level for the logger. 15 is above debug but below info.

        The message is prefixed with the slot and the current time of the step.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step
            level (int): the level passed to the logger

        Returns:
            None
        """
        self.logger.log(level, f'Process {S.status.slot:2d} on time {S.time:.6f} - {msg}')
        return None

    def debug(self, msg, S, **kwargs):
        """
        Shortcut to pass messages at debug level to the logger.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step

        Returns:
            None
        """
        self.log(msg=msg, S=S, level=10, **kwargs)
        return None

    def setup(self, controller, params, description, **kwargs):
        """
        Setup various variables that only need to be set once in the beginning.
        If the convergence controller is added automatically, you can give it params by adding it manually.
        It will be instantiated only once with the manually supplied parameters overriding automatically added
        parameters.

        This function scans the convergence controllers supplied to the description object for instances of itself.
        This corresponds to the convergence controller being added manually by the user. If something is found, this
        function will then return a composite dictionary from the `params` passed to this function as well as the
        `params` passed manually, with priority to manually added parameters. If you added the convergence controller
        manually, that is of course the same and nothing happens. If, on the other hand, the convergence controller was
        added automatically, the `params` passed here will come from whatever added it and you can now override
        parameters by adding the convergence controller manually.
        This relies on children classes to return a composite dictionary from their defaults and from the result of this
        function, so you should write
        ```
        return {**defaults, **super().setup(controller, params, description, **kwargs)}
        ```
        when overloading this method in a child class, with `defaults` a dictionary containing default parameters.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary after setup
        """
        # allow to change parameters by adding the convergence controller manually
        return {**params, **description.get('convergence_controllers', {}).get(type(self), {})}

    def dependencies(self, controller, description, **kwargs):
        """
        Load dependencies on other convergence controllers here.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        pass

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check whether parameters are compatible with whatever assumptions went into the step size functions etc.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        return True, ""

    def check_iteration_status(self, controller, S, **kwargs):
        """
        Determine whether to keep iterating or not in this function.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def get_new_step_size(self, controller, S, **kwargs):
        """
        This function allows to set a step size with arbitrary criteria.
        Make sure to give an order to the convergence controller by setting the `control_order` variable in the params.
        This variable is an integer and you can see what the current order is by using
        `controller.print_convergence_controllers()`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def determine_restart(self, controller, S, **kwargs):
        """
        Determine for each step separately if it wants to be restarted for whatever reason.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def reset_status_variables(self, controller, **kwargs):
        """
        Reset status variables.
        This is called in the `restart_block` function.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def setup_status_variables(self, controller, **kwargs):
        """
        Setup status variables.
        This is not done at the time of instantiation, since the controller is not fully instantiated at that time and
        hence not all information is available. Instead, this function is called after the controller has been fully
        instantiated.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Buffers refer to variables used across multiple steps that are stored in the convergence controller classes to
        imitate communication in non MPI versions. These have to be reset in order to replicate availability of
        variables in MPI versions.

        For instance, if step 0 sets self.buffers.x = 1 from self.buffers.x = 0, when the same MPI rank uses the
        variable with step 1, it will still carry the value of self.buffers.x = 1, equivalent to a send from the rank
        computing step 0 to the rank computing step 1.

        However, you can only receive what somebody sent and in order to make sure that is true for the non MPI
        versions, we reset after each iteration so you cannot use this function to communicate backwards from the last
        step to the first one for instance.

        This function is called both at the end of instantiating the controller, as well as after each iteration.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        pass

    def pre_iteration_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to before each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_step_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after each step here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_run_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after the run here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def prepare_next_block(self, controller, S, size, time, Tend, **kwargs):
        """
        Prepare stuff like spreading step sizes or whatever.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (float): The current time (a list of times in the nonMPI controller implementation)
            Tend (float): The final time

        Returns:
            None
        """
        pass

    def convergence_control(self, controller, S, **kwargs):
        """
        Call all the functions related to convergence control.
        This is called in `it_check` in the controller after every iteration just after `post_iteration_processing`.
        The hooks are called in the order: `get_new_step_size`, `determine_restart`, `check_iteration_status`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """

        self.get_new_step_size(controller, S, **kwargs)
        self.determine_restart(controller, S, **kwargs)
        self.check_iteration_status(controller, S, **kwargs)

        return None

    def post_spread_processing(self, controller, S, **kwargs):
        """
        This function is called at the end of the `SPREAD` stage in the controller

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
        """
        pass

    def send(self, comm, dest, data, blocking=False, **kwargs):
        """
        Send (pickled) data to a different rank

        Unless a `tag` is given in the keyword arguments, the absolute value of `control_order` is used as tag (MPI
        tags must be non-negative).

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            data: Data to be sent
            blocking (bool): Whether the communication is blocking or not

        Returns:
            The request handle of `comm.isend` if non-blocking, else whatever `comm.send` returns
        """
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.send(data, dest=dest, **kwargs)
        else:
            req = comm.isend(data, dest=dest, **kwargs)

        return req

    def recv(self, comm, source, **kwargs):
        """
        Receive some (pickled) data with a blocking call

        Unless a `tag` is given in the keyword arguments, the absolute value of `control_order` is used as tag.

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving

        Returns:
            whatever has been received
        """
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} receives from step {source} with tag {kwargs["tag"]}')

        data = comm.recv(source=source, **kwargs)

        return data

    def Send(self, comm, dest, buffer, blocking=False, **kwargs):
        """
        Send buffer-like data to a different rank

        Unless a `tag` is given in the keyword arguments, the absolute value of `control_order` is used as tag.

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            buffer: Buffer for the data
            blocking (bool): Whether the communication is blocking or not

        Returns:
            The request handle of `comm.Isend` if non-blocking, else whatever `comm.Send` returns
        """
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}Sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.Send(buffer, dest=dest, **kwargs)
        else:
            req = comm.Isend(buffer, dest=dest, **kwargs)

        return req

    def Recv(self, comm, source, buffer, **kwargs):
        """
        Receive buffer-like data with a blocking call

        Unless a `tag` is given in the keyword arguments, the absolute value of `control_order` is used as tag.

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving
            buffer: Buffer that the received data is written into

        Returns:
            Whatever `comm.Recv` returns; the received data itself ends up in `buffer`
        """
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} Receives from step {source} with tag {kwargs["tag"]}')

        data = comm.Recv(buffer, source=source, **kwargs)

        return data

    def _get_steps(self):
        """
        Return the steps handled by this process: `[controller.S]` for the MPI controller, `controller.MS` otherwise.
        """
        # The controller type is identified by its class name, so the MPI controller module need not be imported.
        if type(self.controller).__name__ == 'controller_MPI':
            return [self.controller.S]
        return self.controller.MS

    def add_status_variable_to_step(self, key, value=None):
        """
        Add a status variable to the step(s) handled by this process.

        Args:
            key (str): Name of the status variable
            value: Initial value. If it is not None, it is written to the status of every step.

        Returns:
            None
        """
        # NOTE(review): `add_attr` is only called on the first step; presumably it registers the attribute for the
        # status class shared by all steps -- TODO confirm.
        self._get_steps()[0].status.add_attr(key)

        if value is not None:
            self.set_step_status_variable(key, value)

    def set_step_status_variable(self, key, value):
        """
        Set the status variable `key` to `value` on every step handled by this process.

        Args:
            key (str): Name of the status variable
            value: The value to set

        Returns:
            None
        """
        # NOTE(review): writing into `__dict__` presumably bypasses the FrozenClass attribute guard -- confirm.
        for S in self._get_steps():
            S.status.__dict__[key] = value

    def add_status_variable_to_level(self, key, value=None):
        """
        Add a status variable to the levels of the step(s) handled by this process.

        Args:
            key (str): Name of the status variable
            value: Initial value. If it is not None, it is written to the status of every level of every step.

        Returns:
            None
        """
        # NOTE(review): `add_attr` is only called on the first level of the first step; presumably it registers the
        # attribute for the shared level status class -- TODO confirm.
        self._get_steps()[0].levels[0].status.add_attr(key)

        if value is not None:
            self.set_level_status_variable(key, value)

    def set_level_status_variable(self, key, value):
        """
        Set the status variable `key` to `value` on every level of every step handled by this process.

        Args:
            key (str): Name of the status variable
            value: The value to set

        Returns:
            None
        """
        for S in self._get_steps():
            for L in S.levels:
                L.status.__dict__[key] = value