Coverage for pySDC / core / convergence_controller.py: 83%

117 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-12 11:13 +0000

1import logging 

2from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING 

3from pySDC.helpers.pysdc_helper import FrozenClass 

4 

5if TYPE_CHECKING: 

6 from pySDC.core.controller import Controller 

7 from pySDC.core.step import Step 

8 

9 

# short helper class to add params as attributes
class Pars(FrozenClass):
    """
    Expose the entries of a parameter dictionary as attributes.

    Two defaults, `control_order` and `useMPI`, are assigned first, so matching keys in `params` override them.
    Afterwards `_freeze()` (inherited from FrozenClass) is called; see FrozenClass for what freezing enforces.
    """

    def __init__(self, params: Dict[str, Any]) -> None:
        # integer that determines the order in which the convergence controllers are called
        self.control_order: int = 0
        # depends on the controller
        self.useMPI: Optional[bool] = None

        # copy every user-supplied entry onto the instance, overriding the defaults above where keys coincide
        for name, value in params.items():
            setattr(self, name, value)

        self._freeze()

20 

21 

# short helper class to store status variables
class Status(FrozenClass):
    """
    Container for status variables of a convergence controller.

    All status variables are initialized with None, since at the time the convergence controllers are instantiated,
    not all relevant information about the controller is known.
    """

    def __init__(self, status_variabes: List[str]) -> None:
        """
        Create one attribute per requested status variable, each initialized with None, then freeze the object.

        Args:
            status_variabes (list of str): Names of the status variables to create. The parameter name is
                misspelled, but it is kept as-is so callers passing it by keyword keep working.
        """
        # A plain loop instead of a list comprehension that was evaluated only for its side effects
        for key in status_variabes:
            setattr(self, key, None)

        self._freeze()

33 

34 

35class ConvergenceController(object): 

36 """ 

37 Base abstract class for convergence controller, which is plugged into the controller to determine the iteration 

38 count and time step size. 

39 """ 

40 

41 def __init__( 

42 self, controller: 'Controller', params: Dict[str, Any], description: Dict[str, Any], **kwargs: Any 

43 ) -> None: 

44 """ 

45 Initialization routine 

46 

47 Args: 

48 controller (pySDC.Controller): The controller 

49 params (dict): The params passed for this specific convergence controller 

50 description (dict): The description object used to instantiate the controller 

51 """ 

52 self.controller: 'Controller' = controller 

53 self.params: Pars = Pars(self.setup(controller, params, description)) 

54 params_ok, msg = self.check_parameters(controller, params, description) 

55 assert params_ok, f'{type(self).__name__} -- {msg}' 

56 self.dependencies(controller, description) 

57 self.logger: logging.Logger = logging.getLogger(f"{type(self).__name__}") 

58 

59 if self.params.useMPI: 

60 self.prepare_MPI_datatypes() 

61 

62 def prepare_MPI_logical_operations(self) -> None: 

63 """ 

64 Prepare MPI logical operations so we don't need to import mpi4py all the time 

65 """ 

66 from mpi4py import MPI 

67 

68 self.MPI_LAND = MPI.LAND 

69 self.MPI_LOR = MPI.LOR 

70 

71 def prepare_MPI_datatypes(self) -> None: 

72 """ 

73 Prepare MPI datatypes so we don't need to import mpi4py all the time 

74 """ 

75 from mpi4py import MPI 

76 

77 self.MPI_INT = MPI.INT 

78 self.MPI_DOUBLE = MPI.DOUBLE 

79 self.MPI_BOOL = MPI.BOOL 

80 

81 def log(self, msg: str, S: 'Step', level: int = 15, **kwargs: Any) -> None: 

82 """ 

83 Shortcut that has a default level for the logger. 15 is above debug but below info. 

84 

85 Args: 

86 msg (str): Message you want to log 

87 S (pySDC.step): The current step 

88 level (int): the level passed to the logger 

89 

90 Returns: 

91 None 

92 """ 

93 self.logger.log(level, f'Process {S.status.slot:2d} on time {S.time:.6f} - {msg}') 

94 return None 

95 

96 def debug(self, msg: str, S: 'Step', **kwargs: Any) -> None: 

97 """ 

98 Shortcut to pass messages at debug level to the logger. 

99 

100 Args: 

101 msg (str): Message you want to log 

102 S (pySDC.step): The current step 

103 

104 Returns: 

105 None 

106 """ 

107 self.log(msg=msg, S=S, level=10, **kwargs) 

108 return None 

109 

110 def setup( 

111 self, controller: 'Controller', params: Dict[str, Any], description: Dict[str, Any], **kwargs: Any 

112 ) -> Dict[str, Any]: 

113 """ 

114 Setup various variables that only need to be set once in the beginning. 

115 If the convergence controller is added automatically, you can give it params by adding it manually. 

116 It will be instantiated only once with the manually supplied parameters overriding automatically added 

117 parameters. 

118 

119 This function scans the convergence controllers supplied to the description object for instances of itself. 

120 This corresponds to the convergence controller being added manually by the user. If something is found, this 

121 function will then return a composite dictionary from the `params` passed to this function as well as the 

122 `params` passed manually, with priority to manually added parameters. If you added the convergence controller 

123 manually, that is of course the same and nothing happens. If, on the other hand, the convergence controller was 

124 added automatically, the `params` passed here will come from whatever added it and you can now override 

125 parameters by adding the convergence controller manually. 

126 This relies on children classes to return a composite dictionary from their defaults and from the result of this 

127 function, so you should write 

128 ``` 

129 return {**defaults, **super().setup(controller, params, description, **kwargs)} 

130 ``` 

131 when overloading this method in a child class, with `defaults` a dictionary containing default parameters. 

132 

133 Args: 

134 controller (pySDC.Controller): The controller 

135 params (dict): The params passed for this specific convergence controller 

136 description (dict): The description object used to instantiate the controller 

137 

138 Returns: 

139 (dict): The updated params dictionary after setup 

140 """ 

141 # allow to change parameters by adding the convergence controller manually 

142 return {**params, **description.get('convergence_controllers', {}).get(type(self), {})} 

143 

144 def dependencies(self, controller: 'Controller', description: Dict[str, Any], **kwargs: Any) -> None: 

145 """ 

146 Load dependencies on other convergence controllers here. 

147 

148 Args: 

149 controller (pySDC.Controller): The controller 

150 description (dict): The description object used to instantiate the controller 

151 

152 Returns: 

153 None 

154 """ 

155 pass 

156 

157 def check_parameters( 

158 self, controller: 'Controller', params: Dict[str, Any], description: Dict[str, Any], **kwargs: Any 

159 ) -> Tuple[bool, str]: 

160 """ 

161 Check whether parameters are compatible with whatever assumptions went into the step size functions etc. 

162 

163 Args: 

164 controller (pySDC.Controller): The controller 

165 params (dict): The params passed for this specific convergence controller 

166 description (dict): The description object used to instantiate the controller 

167 

168 Returns: 

169 bool: Whether the parameters are compatible 

170 str: The error message 

171 """ 

172 return True, "" 

173 

174 def check_iteration_status(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

175 """ 

176 Determine whether to keep iterating or not in this function. 

177 

178 Args: 

179 controller (pySDC.Controller): The controller 

180 S (pySDC.Step): The current step 

181 

182 Returns: 

183 None 

184 """ 

185 pass 

186 

187 def get_new_step_size(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

188 """ 

189 This function allows to set a step size with arbitrary criteria. 

190 Make sure to give an order to the convergence controller by setting the `control_order` variable in the params. 

191 This variable is an integer and you can see what the current order is by using 

192 `controller.print_convergence_controllers()`. 

193 

194 Args: 

195 controller (pySDC.Controller): The controller 

196 S (pySDC.Step): The current step 

197 

198 Returns: 

199 None 

200 """ 

201 pass 

202 

203 def determine_restart(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

204 """ 

205 Determine for each step separately if it wants to be restarted for whatever reason. 

206 

207 Args: 

208 controller (pySDC.Controller): The controller 

209 S (pySDC.Step): The current step 

210 

211 Returns: 

212 None 

213 """ 

214 pass 

215 

216 def reset_status_variables(self, controller: 'Controller', **kwargs: Any) -> None: 

217 """ 

218 Reset status variables. 

219 This is called in the `restart_block` function. 

220 Args: 

221 controller (pySDC.Controller): The controller 

222 

223 Returns: 

224 None 

225 """ 

226 return None 

227 

228 def setup_status_variables(self, controller: 'Controller', **kwargs: Any) -> None: 

229 """ 

230 Setup status variables. 

231 This is not done at the time of instantiation, since the controller is not fully instantiated at that time and 

232 hence not all information are available. Instead, this function is called after the controller has been fully 

233 instantiated. 

234 

235 Args: 

236 controller (pySDC.Controller): The controller 

237 

238 Returns: 

239 None 

240 """ 

241 return None 

242 

243 def reset_buffers_nonMPI(self, controller: 'Controller', **kwargs: Any) -> None: 

244 """ 

245 Buffers refer to variables used across multiple steps that are stored in the convergence controller classes to 

246 imitate communication in non MPI versions. These have to be reset in order to replicate availability of 

247 variables in MPI versions. 

248 

249 For instance, if step 0 sets self.buffers.x = 1 from self.buffers.x = 0, when the same MPI rank uses the 

250 variable with step 1, it will still carry the value of self.buffers.x = 1, equivalent to a send from the rank 

251 computing step 0 to the rank computing step 1. 

252 

253 However, you can only receive what somebody sent and in order to make sure that is true for the non MPI 

254 versions, we reset after each iteration so you cannot use this function to communicate backwards from the last 

255 step to the first one for instance. 

256 

257 This function is called both at the end of instantiating the controller, as well as after each iteration. 

258 

259 Args: 

260 controller (pySDC.Controller): The controller 

261 

262 Returns: 

263 None 

264 """ 

265 pass 

266 

267 def pre_iteration_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

268 """ 

269 Do whatever you want to before each iteration here. 

270 

271 Args: 

272 controller (pySDC.Controller): The controller 

273 S (pySDC.Step): The current step 

274 

275 Returns: 

276 None 

277 """ 

278 pass 

279 

280 def post_iteration_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

281 """ 

282 Do whatever you want to after each iteration here. 

283 

284 Args: 

285 controller (pySDC.Controller): The controller 

286 S (pySDC.Step): The current step 

287 

288 Returns: 

289 None 

290 """ 

291 pass 

292 

293 def post_step_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

294 """ 

295 Do whatever you want to after each step here. 

296 

297 Args: 

298 controller (pySDC.Controller): The controller 

299 S (pySDC.Step): The current step 

300 

301 Returns: 

302 None 

303 """ 

304 pass 

305 

306 def post_run_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

307 """ 

308 Do whatever you want to after the run here. 

309 

310 Args: 

311 controller (pySDC.Controller): The controller 

312 S (pySDC.Step): The current step 

313 

314 Returns: 

315 None 

316 """ 

317 pass 

318 

319 def prepare_next_block( 

320 self, controller: 'Controller', S: 'Step', size: int, time: Any, Tend: float, **kwargs: Any 

321 ) -> None: 

322 """ 

323 Prepare stuff like spreading step sizes or whatever. 

324 

325 Args: 

326 controller (pySDC.Controller): The controller 

327 S (pySDC.Step): The current step 

328 size (int): The number of ranks 

329 time (float): The current time will be list in nonMPI controller implementation 

330 Tend (float): The final time 

331 

332 Returns: 

333 None 

334 """ 

335 pass 

336 

337 def convergence_control(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

338 """ 

339 Call all the functions related to convergence control. 

340 This is called in `it_check` in the controller after every iteration just after `post_iteration_processing`. 

341 Args: 

342 controller (pySDC.Controller): The controller 

343 S (pySDC.Step): The current step 

344 

345 Returns: 

346 None 

347 """ 

348 

349 self.get_new_step_size(controller, S, **kwargs) 

350 self.determine_restart(controller, S, **kwargs) 

351 self.check_iteration_status(controller, S, **kwargs) 

352 

353 return None 

354 

355 def post_spread_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None: 

356 """ 

357 This function is called at the end of the `SPREAD` stage in the controller 

358 

359 Args: 

360 controller (pySDC.Controller): The controller 

361 S (pySDC.Step): The current step 

362 """ 

363 pass 

364 

365 def send(self, comm, dest, data, blocking=False, **kwargs): 

366 """ 

367 Send data to a different rank 

368 

369 Args: 

370 comm (mpi4py.MPI.Intracomm): Communicator 

371 dest (int): The target rank 

372 data: Data to be sent 

373 blocking (bool): Whether the communication is blocking or not 

374 

375 Returns: 

376 request handle of the communication 

377 """ 

378 kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order)) 

379 

380 # log what's happening for debug purposes 

381 self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}sends to step {dest} with tag {kwargs["tag"]}') 

382 

383 if blocking: 

384 req = comm.send(data, dest=dest, **kwargs) 

385 else: 

386 req = comm.isend(data, dest=dest, **kwargs) 

387 

388 return req 

389 

390 def recv(self, comm, source, **kwargs): 

391 """ 

392 Receive some data 

393 

394 Args: 

395 comm (mpi4py.MPI.Intracomm): Communicator 

396 source (int): Where to look for receiving 

397 

398 Returns: 

399 whatever has been received 

400 """ 

401 kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order)) 

402 

403 # log what's happening for debug purposes 

404 self.logger.debug(f'Step {comm.rank} receives from step {source} with tag {kwargs["tag"]}') 

405 

406 data = comm.recv(source=source, **kwargs) 

407 

408 return data 

409 

410 def Send(self, comm, dest, buffer, blocking=False, **kwargs): 

411 """ 

412 Send data to a different rank 

413 

414 Args: 

415 comm (mpi4py.MPI.Intracomm): Communicator 

416 dest (int): The target rank 

417 buffer: Buffer for the data 

418 blocking (bool): Whether the communication is blocking or not 

419 

420 Returns: 

421 request handle of the communication 

422 """ 

423 kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order)) 

424 

425 # log what's happening for debug purposes 

426 self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}Sends to step {dest} with tag {kwargs["tag"]}') 

427 

428 if blocking: 

429 req = comm.Send(buffer, dest=dest, **kwargs) 

430 else: 

431 req = comm.Isend(buffer, dest=dest, **kwargs) 

432 

433 return req 

434 

435 def Recv(self, comm, source, buffer, **kwargs): 

436 """ 

437 Receive some data 

438 

439 Args: 

440 comm (mpi4py.MPI.Intracomm): Communicator 

441 source (int): Where to look for receiving 

442 

443 Returns: 

444 whatever has been received 

445 """ 

446 kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order)) 

447 

448 # log what's happening for debug purposes 

449 self.logger.debug(f'Step {comm.rank} Receives from step {source} with tag {kwargs["tag"]}') 

450 

451 data = comm.Recv(buffer, source=source, **kwargs) 

452 

453 return data 

454 

455 def add_status_variable_to_step(self, key, value=None): 

456 if type(self.controller).__name__ == 'controller_MPI': 

457 steps = [self.controller.S] 

458 else: 

459 steps = self.controller.MS 

460 

461 steps[0].status.add_attr(key) 

462 

463 if value is not None: 

464 self.set_step_status_variable(key, value) 

465 

466 def set_step_status_variable(self, key, value): 

467 if type(self.controller).__name__ == 'controller_MPI': 

468 steps = [self.controller.S] 

469 else: 

470 steps = self.controller.MS 

471 

472 for S in steps: 

473 S.status.__dict__[key] = value 

474 

475 def add_status_variable_to_level(self, key, value=None): 

476 if type(self.controller).__name__ == 'controller_MPI': 

477 steps = [self.controller.S] 

478 else: 

479 steps = self.controller.MS 

480 

481 steps[0].levels[0].status.add_attr(key) 

482 

483 if value is not None: 

484 self.set_level_status_variable(key, value) 

485 

486 def set_level_status_variable(self, key, value): 

487 if type(self.controller).__name__ == 'controller_MPI': 

488 steps = [self.controller.S] 

489 else: 

490 steps = self.controller.MS 

491 

492 for S in steps: 

493 for L in S.levels: 

494 L.status.__dict__[key] = value