Coverage for pySDC/core/ConvergenceController.py: 89%

99 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import logging 

2from pySDC.helpers.pysdc_helper import FrozenClass 

3 

4 

# short helper class to add params as attributes
class Pars(FrozenClass):
    """
    Attribute container for the parameters of a convergence controller.

    Every instance starts out with the two entries all convergence controllers rely on, `control_order` and
    `useMPI`. The entries of `params` are then set on top of them, overriding these defaults where they clash, and
    finally `_freeze()` is called.
    """

    def __init__(self, params):
        """
        Args:
            params (dict): Parameters to expose as attributes
        """
        # `control_order`: integer that determines the order in which the convergence controllers are called
        # `useMPI`: depends on the controller
        defaults = (('control_order', 0), ('useMPI', None))

        # defaults come first so that user supplied entries take precedence
        for attr, value in (*defaults, *params.items()):
            setattr(self, attr, value)

        self._freeze()

15 

16 

# short helper class to store status variables
class Status(FrozenClass):
    """
    Initialize status variables with None, since at the time of instantiation of the convergence controllers, not all
    relevant information about the controller is known.
    """

    def __init__(self, status_variabes):
        """
        Args:
            status_variabes (iterable of str): Names of the status variables to create, each initialized to None.
                The parameter keeps its historical spelling so that keyword callers keep working.
        """
        # plain loop rather than a list comprehension evaluated only for its side effects
        for key in status_variabes:
            setattr(self, key, None)

        self._freeze()

28 

29 

class ConvergenceController(object):
    """
    Base abstract class for convergence controller, which is plugged into the controller to determine the iteration
    count and time step size.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Raises:
            AssertionError: If `check_parameters` reports that the parameters are not compatible
        """
        self.params = Pars(self.setup(controller, params, description))

        # explicit check instead of `assert`: an assert statement is stripped when Python runs with -O, which would
        # silently skip parameter validation. The exception type is kept so existing callers/tests are unaffected.
        params_ok, msg = self.check_parameters(controller, params, description)
        if not params_ok:
            raise AssertionError(f'{type(self).__name__} -- {msg}')

        self.dependencies(controller, description)
        self.logger = logging.getLogger(f"{type(self).__name__}")

        if self.params.useMPI:
            self.prepare_MPI_datatypes()

    def prepare_MPI_logical_operations(self):
        """
        Prepare MPI logical operations so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_LAND = MPI.LAND
        self.MPI_LOR = MPI.LOR

    def prepare_MPI_datatypes(self):
        """
        Prepare MPI datatypes so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_INT = MPI.INT
        self.MPI_DOUBLE = MPI.DOUBLE
        self.MPI_BOOL = MPI.BOOL

    def log(self, msg, S, level=15, **kwargs):
        """
        Shortcut that has a default level for the logger. 15 is above debug but below info.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step
            level (int): the level passed to the logger

        Returns:
            None
        """
        self.logger.log(level, f'Process {S.status.slot:2d} on time {S.time:.6f} - {msg}')
        return None

    def debug(self, msg, S, **kwargs):
        """
        Shortcut to pass messages at debug level to the logger.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step

        Returns:
            None
        """
        self.log(msg=msg, S=S, level=10, **kwargs)
        return None

    def setup(self, controller, params, description, **kwargs):
        """
        Setup various variables that only need to be set once in the beginning.
        If the convergence controller is added automatically, you can give it params by adding it manually.
        It will be instantiated only once with the manually supplied parameters overriding automatically added
        parameters.

        This function scans the convergence controllers supplied to the description object for instances of itself.
        This corresponds to the convergence controller being added manually by the user. If something is found, this
        function will then return a composite dictionary from the `params` passed to this function as well as the
        `params` passed manually, with priority to manually added parameters. If you added the convergence controller
        manually, that is of course the same and nothing happens. If, on the other hand, the convergence controller was
        added automatically, the `params` passed here will come from whatever added it and you can now override
        parameters by adding the convergence controller manually.
        This relies on children classes to return a composite dictionary from their defaults and from the result of this
        function, so you should write
        ```
        return {**defaults, **super().setup(controller, params, description, **kwargs)}
        ```
        when overloading this method in a child class, with `defaults` a dictionary containing default parameters.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary after setup
        """
        # allow to change parameters by adding the convergence controller manually
        return {**params, **description.get('convergence_controllers', {}).get(type(self), {})}

    def dependencies(self, controller, description, **kwargs):
        """
        Load dependencies on other convergence controllers here.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        pass

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check whether parameters are compatible with whatever assumptions went into the step size functions etc.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        return True, ""

    def check_iteration_status(self, controller, S, **kwargs):
        """
        Determine whether to keep iterating or not in this function.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def get_new_step_size(self, controller, S, **kwargs):
        """
        This function allows to set a step size with arbitrary criteria.
        Make sure to give an order to the convergence controller by setting the `control_order` variable in the params.
        This variable is an integer and you can see what the current order is by using
        `controller.print_convergence_controllers()`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def determine_restart(self, controller, S, **kwargs):
        """
        Determine for each step separately if it wants to be restarted for whatever reason.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def reset_status_variables(self, controller, **kwargs):
        """
        Reset status variables.
        This is called in the `restart_block` function.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def setup_status_variables(self, controller, **kwargs):
        """
        Setup status variables.
        This is not done at the time of instantiation, since the controller is not fully instantiated at that time and
        hence not all information is available. Instead, this function is called after the controller has been fully
        instantiated.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Buffers refer to variables used across multiple steps that are stored in the convergence controller classes to
        imitate communication in non MPI versions. These have to be reset in order to replicate availability of
        variables in MPI versions.

        For instance, if step 0 sets self.buffers.x = 1 from self.buffers.x = 0, when the same MPI rank uses the
        variable with step 1, it will still carry the value of self.buffers.x = 1, equivalent to a send from the rank
        computing step 0 to the rank computing step 1.

        However, you can only receive what somebody sent and in order to make sure that is true for the non MPI
        versions, we reset after each iteration so you cannot use this function to communicate backwards from the last
        step to the first one for instance.

        This function is called both at the end of instantiating the controller, as well as after each iteration.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        pass

    def pre_iteration_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to before each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_step_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after each step here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def prepare_next_block(self, controller, S, size, time, Tend, **kwargs):
        """
        Prepare stuff like spreading step sizes or whatever.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (float): The current time (a list of times in the nonMPI controller implementation)
            Tend (float): The final time

        Returns:
            None
        """
        pass

    def convergence_control(self, controller, S, **kwargs):
        """
        Call all the functions related to convergence control.
        This is called in `it_check` in the controller after every iteration just after `post_iteration_processing`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """

        self.get_new_step_size(controller, S, **kwargs)
        self.determine_restart(controller, S, **kwargs)
        self.check_iteration_status(controller, S, **kwargs)

        return None

    def post_spread_processing(self, controller, S, **kwargs):
        """
        This function is called at the end of the `SPREAD` stage in the controller

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
        """
        pass

    def send(self, comm, dest, data, blocking=False, **kwargs):
        """
        Send data to a different rank

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            data: Data to be sent
            blocking (bool): Whether the communication is blocking or not

        Returns:
            request handle of the communication
        """
        # default tag derived from the control order, unless the caller supplies one
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.send(data, dest=dest, **kwargs)
        else:
            req = comm.isend(data, dest=dest, **kwargs)

        return req

    def recv(self, comm, source, **kwargs):
        """
        Receive some data

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving

        Returns:
            whatever has been received
        """
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} receives from step {source} with tag {kwargs["tag"]}')

        data = comm.recv(source=source, **kwargs)

        return data

    def Send(self, comm, dest, buffer, blocking=False, **kwargs):
        """
        Send data to a different rank

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            buffer: Buffer for the data
            blocking (bool): Whether the communication is blocking or not

        Returns:
            request handle of the communication
        """
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}Sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.Send(buffer, dest=dest, **kwargs)
        else:
            req = comm.Isend(buffer, dest=dest, **kwargs)

        return req

    def Recv(self, comm, source, buffer, **kwargs):
        """
        Receive some data into a buffer

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving
            buffer: Buffer the received data is written into

        Returns:
            whatever `comm.Recv` returns; the received data itself ends up in `buffer`
        """
        kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} Receives from step {source} with tag {kwargs["tag"]}')

        data = comm.Recv(buffer, source=source, **kwargs)

        return data

    def reset_variable(self, controller, name, MPI=False, place=None, where=None, init=None):
        """
        Utility function for resetting variables. This function will call the `add_variable` function with all the same
        arguments, but with `allow_overwrite = True`.

        Args:
            controller (pySDC.Controller): The controller
            name (str): The name of the variable
            MPI (bool): Whether to use MPI controller
            place (object): The object you want to reset the variable of
            where (list): List of strings containing a path to where you want to reset the variable
            init: Initial value of the variable

        Returns:
            None
        """
        self.add_variable(controller, name, MPI, place, where, init, allow_overwrite=True)

    def add_variable(self, controller, name, MPI=False, place=None, where=None, init=None, allow_overwrite=False):
        """
        Add a variable to a frozen class.

        This function goes through the path to the destination of the variable recursively and adds it to all instances
        that are possible in the path. For example, giving `where = ["MS", "levels", "status"]` will result in adding a
        variable to the status object of all levels of all steps of the controller.

        Part of the functionality of the frozen class is to separate initialization and setting of variables. By
        enforcing this, you can make sure not to overwrite already existing variables. Since this function is called
        outside of the `__init__` function of the status objects, this can otherwise lead to bugs that are hard to find.
        For this reason, you need to specifically set `allow_overwrite = True` if you want to forgo the check if the
        variable already exists. This can be useful when resetting variables between steps, but make sure to set it to
        `allow_overwrite = False` the first time you add a variable.

        Args:
            controller (pySDC.Controller): The controller
            name (str): The name of the variable
            MPI (bool): Whether to use MPI controller
            place (object): The object you want to add the variable to
            where (list): List of strings containing a path to where you want to add the variable
            init: Initial value of the variable
            allow_overwrite (bool): Allow overwriting the variables if they already exist or raise an exception

        Returns:
            None

        Raises:
            ValueError: If the variable already exists and `allow_overwrite` is False, or if it does not exist and
                `allow_overwrite` is True
        """
        where = ["S" if MPI else "MS", "levels", "status"] if where is None else where
        place = controller if place is None else place

        # check if we have arrived at the end of the path to the variable
        if not where:
            variable_exists = name in place.__dict__
            # check if the variable already exists and raise an error in case we are about to introduce a bug
            if not allow_overwrite and variable_exists:
                raise ValueError(f"Key \"{name}\" already exists in {place}! Please rename the variable in {self}")
            # if we allow overwriting, but the variable does not exist already, we are violating the intended purpose
            # of this function, so we also raise an error if someone should be so mad as to attempt this
            elif allow_overwrite and not variable_exists:
                raise ValueError(f"Key \"{name}\" is supposed to be overwritten in {place}, but it does not exist!")

            # actually add or overwrite the variable; writing to `__dict__` directly bypasses `__setattr__`, which is
            # what allows adding attributes to frozen objects
            place.__dict__[name] = init
            return None

        # follow the path to the final destination recursively: a list (including list subclasses) branches the path
        # into all of its entries, anything else is the single next stop
        new_places = place.__dict__[where[0]]
        targets = new_places if isinstance(new_places, list) else [new_places]

        for new_place in targets:
            self.add_variable(
                controller,
                name,
                MPI=MPI,
                place=new_place,
                where=where[1:],
                init=init,
                allow_overwrite=allow_overwrite,
            )