Coverage for pySDC / core / convergence_controller.py: 83%
117 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-12 11:13 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-12 11:13 +0000
1import logging
2from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
3from pySDC.helpers.pysdc_helper import FrozenClass
5if TYPE_CHECKING:
6 from pySDC.core.controller import Controller
7 from pySDC.core.step import Step
# small container class exposing a parameter dictionary as attributes
class Pars(FrozenClass):
    """
    Parameters of a convergence controller, accessible as attributes.

    Two defaults are always present and can be overridden by the supplied dictionary:
    `control_order` (0) and `useMPI` (None).
    """

    def __init__(self, params: Dict[str, Any]) -> None:
        """
        Args:
            params (dict): Parameter names mapped to their values; every entry becomes an attribute
        """
        # defaults go first so that entries of `params` take precedence
        self.control_order: int = 0  # integer that determines the order in which the convergence controllers are called
        self.useMPI: Optional[bool] = None  # depends on the controller

        for name in params:
            setattr(self, name, params[name])

        # inherited from FrozenClass; presumably locks the set of attributes -- confirm in pysdc_helper
        self._freeze()
# short helper class to store status variables
class Status(FrozenClass):
    """
    Initialize status variables with None, since at the time of instantiation of the convergence controllers, not all
    relevant information about the controller are known.
    """

    def __init__(self, status_variabes: List[str]) -> None:
        """
        Args:
            status_variabes (list): Names of the status variables to create. The (misspelled) parameter name is
                                    kept as is so that callers passing it by keyword continue to work.
        """
        # plain loop: a list comprehension used only for its side effects would build a throwaway list of None
        for key in status_variabes:
            setattr(self, key, None)

        # inherited from FrozenClass; presumably locks the set of attributes -- confirm in pysdc_helper
        self._freeze()
class ConvergenceController(object):
    """
    Base abstract class for convergence controller, which is plugged into the controller to determine the iteration
    count and time step size.

    All hook functions in this base class do nothing; child classes overload the ones they need. Besides the hooks,
    this class provides logging shortcuts, thin wrappers around MPI point-to-point communication with a default tag
    and helpers to add and set status variables on steps and levels.
    """

    def __init__(
        self, controller: 'Controller', params: Dict[str, Any], description: Dict[str, Any], **kwargs: Any
    ) -> None:
        """
        Initialization routine

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Raises:
            AssertionError: If `check_parameters` reports incompatible parameters
        """
        self.controller: 'Controller' = controller
        self.params: Pars = Pars(self.setup(controller, params, description))

        # NOTE: `check_parameters` receives the raw `params`, not the merged `self.params`.
        # NOTE(review): this check is an `assert` and hence skipped when Python runs with `-O`. It is left as is
        # because callers may rely on the `AssertionError`.
        params_ok, msg = self.check_parameters(controller, params, description)
        assert params_ok, f'{type(self).__name__} -- {msg}'

        self.dependencies(controller, description)
        self.logger: logging.Logger = logging.getLogger(f"{type(self).__name__}")

        # only the datatypes are prepared here; `prepare_MPI_logical_operations` has to be called explicitly
        if self.params.useMPI:
            self.prepare_MPI_datatypes()

    def prepare_MPI_logical_operations(self) -> None:
        """
        Prepare MPI logical operations so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_LAND = MPI.LAND
        self.MPI_LOR = MPI.LOR

    def prepare_MPI_datatypes(self) -> None:
        """
        Prepare MPI datatypes so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_INT = MPI.INT
        self.MPI_DOUBLE = MPI.DOUBLE
        self.MPI_BOOL = MPI.BOOL

    def log(self, msg: str, S: 'Step', level: int = 15, **kwargs: Any) -> None:
        """
        Shortcut that has a default level for the logger. 15 is above debug but below info.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step
            level (int): the level passed to the logger

        Returns:
            None
        """
        # skip building the message when nothing would be emitted; this is called frequently and the default
        # level is usually disabled
        if self.logger.isEnabledFor(level):
            self.logger.log(level, f'Process {S.status.slot:2d} on time {S.time:.6f} - {msg}')
        return None

    def debug(self, msg: str, S: 'Step', **kwargs: Any) -> None:
        """
        Shortcut to pass messages at debug level to the logger.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step

        Returns:
            None
        """
        self.log(msg=msg, S=S, level=10, **kwargs)
        return None

    def setup(
        self, controller: 'Controller', params: Dict[str, Any], description: Dict[str, Any], **kwargs: Any
    ) -> Dict[str, Any]:
        """
        Setup various variables that only need to be set once in the beginning.
        If the convergence controller is added automatically, you can give it params by adding it manually.
        It will be instantiated only once with the manually supplied parameters overriding automatically added
        parameters.

        This function scans the convergence controllers supplied to the description object for instances of itself.
        This corresponds to the convergence controller being added manually by the user. If something is found, this
        function will then return a composite dictionary from the `params` passed to this function as well as the
        `params` passed manually, with priority to manually added parameters. If you added the convergence controller
        manually, that is of course the same and nothing happens. If, on the other hand, the convergence controller was
        added automatically, the `params` passed here will come from whatever added it and you can now override
        parameters by adding the convergence controller manually.
        This relies on children classes to return a composite dictionary from their defaults and from the result of this
        function, so you should write
        ```
        return {**defaults, **super().setup(controller, params, description, **kwargs)}
        ```
        when overloading this method in a child class, with `defaults` a dictionary containing default parameters.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary after setup
        """
        # allow to change parameters by adding the convergence controller manually;
        # the lookup key is the exact class of this instance
        manual_params = description.get('convergence_controllers', {}).get(type(self), {})
        return {**params, **manual_params}

    def dependencies(self, controller: 'Controller', description: Dict[str, Any], **kwargs: Any) -> None:
        """
        Load dependencies on other convergence controllers here.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        pass

    def check_parameters(
        self, controller: 'Controller', params: Dict[str, Any], description: Dict[str, Any], **kwargs: Any
    ) -> Tuple[bool, str]:
        """
        Check whether parameters are compatible with whatever assumptions went into the step size functions etc.
        The base class accepts everything.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        return True, ""

    def check_iteration_status(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        Determine whether to keep iterating or not in this function.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def get_new_step_size(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        This function allows to set a step size with arbitrary criteria.
        Make sure to give an order to the convergence controller by setting the `control_order` variable in the params.
        This variable is an integer and you can see what the current order is by using
        `controller.print_convergence_controllers()`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def determine_restart(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        Determine for each step separately if it wants to be restarted for whatever reason.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def reset_status_variables(self, controller: 'Controller', **kwargs: Any) -> None:
        """
        Reset status variables.
        This is called in the `restart_block` function.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def setup_status_variables(self, controller: 'Controller', **kwargs: Any) -> None:
        """
        Setup status variables.
        This is not done at the time of instantiation, since the controller is not fully instantiated at that time and
        hence not all information are available. Instead, this function is called after the controller has been fully
        instantiated.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def reset_buffers_nonMPI(self, controller: 'Controller', **kwargs: Any) -> None:
        """
        Buffers refer to variables used across multiple steps that are stored in the convergence controller classes to
        imitate communication in non MPI versions. These have to be reset in order to replicate availability of
        variables in MPI versions.

        For instance, if step 0 sets self.buffers.x = 1 from self.buffers.x = 0, when the same MPI rank uses the
        variable with step 1, it will still carry the value of self.buffers.x = 1, equivalent to a send from the rank
        computing step 0 to the rank computing step 1.

        However, you can only receive what somebody sent and in order to make sure that is true for the non MPI
        versions, we reset after each iteration so you cannot use this function to communicate backwards from the last
        step to the first one for instance.

        This function is called both at the end of instantiating the controller, as well as after each iteration.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        pass

    def pre_iteration_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        Do whatever you want to before each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_iteration_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        Do whatever you want to after each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_step_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        Do whatever you want to after each step here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_run_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        Do whatever you want to after the run here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def prepare_next_block(
        self, controller: 'Controller', S: 'Step', size: int, time: Any, Tend: float, **kwargs: Any
    ) -> None:
        """
        Prepare stuff like spreading step sizes or whatever.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (float): The current time; this will be a list in the nonMPI controller implementation
            Tend (float): The final time

        Returns:
            None
        """
        pass

    def convergence_control(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        Call all the functions related to convergence control.
        This is called in `it_check` in the controller after every iteration just after `post_iteration_processing`.

        The hooks are called in a fixed order: `get_new_step_size`, then `determine_restart`, then
        `check_iteration_status`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        self.get_new_step_size(controller, S, **kwargs)
        self.determine_restart(controller, S, **kwargs)
        self.check_iteration_status(controller, S, **kwargs)

        return None

    def post_spread_processing(self, controller: 'Controller', S: 'Step', **kwargs: Any) -> None:
        """
        This function is called at the end of the `SPREAD` stage in the controller

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
        """
        pass

    def send(self, comm, dest, data, blocking=False, **kwargs):
        """
        Send data to a different rank using the lowercase (generic Python object) communication functions.

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            data: Data to be sent
            blocking (bool): Whether the communication is blocking or not
            **kwargs: Passed on to `comm.send` / `comm.isend`. If no `tag` is given, the absolute value of the
                      `control_order` parameter is used.

        Returns:
            Whatever `comm.isend` (non-blocking) or `comm.send` (blocking) returns. With mpi4py this is a request
            handle for the non-blocking variant and None for the blocking one.
        """
        # control_order may be negative, whereas MPI tags must not be, hence the absolute value
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.send(data, dest=dest, **kwargs)
        else:
            req = comm.isend(data, dest=dest, **kwargs)

        return req

    def recv(self, comm, source, **kwargs):
        """
        Receive some data using the lowercase (generic Python object) communication function. This call is blocking.

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving
            **kwargs: Passed on to `comm.recv`. If no `tag` is given, the absolute value of the `control_order`
                      parameter is used.

        Returns:
            whatever has been received
        """
        # control_order may be negative, whereas MPI tags must not be, hence the absolute value
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} receives from step {source} with tag {kwargs["tag"]}')

        data = comm.recv(source=source, **kwargs)

        return data

    def Send(self, comm, dest, buffer, blocking=False, **kwargs):
        """
        Send data to a different rank using the uppercase (buffer based) communication functions.

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            buffer: Buffer for the data
            blocking (bool): Whether the communication is blocking or not
            **kwargs: Passed on to `comm.Send` / `comm.Isend`. If no `tag` is given, the absolute value of the
                      `control_order` parameter is used.

        Returns:
            Whatever `comm.Isend` (non-blocking) or `comm.Send` (blocking) returns. With mpi4py this is a request
            handle for the non-blocking variant and None for the blocking one.
        """
        # control_order may be negative, whereas MPI tags must not be, hence the absolute value
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}Sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.Send(buffer, dest=dest, **kwargs)
        else:
            req = comm.Isend(buffer, dest=dest, **kwargs)

        return req

    def Recv(self, comm, source, buffer, **kwargs):
        """
        Receive some data into a buffer using the uppercase (buffer based) communication function.

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving
            buffer: Buffer that the received data is written to
            **kwargs: Passed on to `comm.Recv`. If no `tag` is given, the absolute value of the `control_order`
                      parameter is used.

        Returns:
            Whatever `comm.Recv` returns. With mpi4py this is None; the received data is found in `buffer`.
        """
        # control_order may be negative, whereas MPI tags must not be, hence the absolute value
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} Receives from step {source} with tag {kwargs["tag"]}')

        data = comm.Recv(buffer, source=source, **kwargs)

        return data

    def _get_steps(self) -> List['Step']:
        """
        Collect the steps held by the controller of this convergence controller.

        The MPI controller is recognised by its class name and holds a single step in `S`. Any other controller is
        expected to hold its steps in `MS`.

        Returns:
            list: The steps of the controller
        """
        if type(self.controller).__name__ == 'controller_MPI':
            return [self.controller.S]
        return self.controller.MS

    def add_status_variable_to_step(self, key, value=None):
        """
        Add a status variable to the step status and optionally give it a value on all steps.

        `add_attr` is called on the status of the first step only.
        NOTE(review): presumably this registers the attribute for all step status objects -- confirm in the
        status class.

        Args:
            key (str): Name of the status variable
            value: Initial value. Nothing is set if this is None.

        Returns:
            None
        """
        steps = self._get_steps()
        steps[0].status.add_attr(key)

        if value is not None:
            self.set_step_status_variable(key, value)

    def set_step_status_variable(self, key, value):
        """
        Set a status variable to the same value on all steps of the controller.

        The value is written directly to the `__dict__` of the status object and hence does not go through its
        `__setattr__`.

        Args:
            key (str): Name of the status variable
            value: The value to set

        Returns:
            None
        """
        for S in self._get_steps():
            S.status.__dict__[key] = value

    def add_status_variable_to_level(self, key, value=None):
        """
        Add a status variable to the level status and optionally give it a value on all levels of all steps.

        `add_attr` is called on the status of the first level of the first step only.
        NOTE(review): presumably this registers the attribute for all level status objects -- confirm in the
        status class.

        Args:
            key (str): Name of the status variable
            value: Initial value. Nothing is set if this is None.

        Returns:
            None
        """
        steps = self._get_steps()
        steps[0].levels[0].status.add_attr(key)

        if value is not None:
            self.set_level_status_variable(key, value)

    def set_level_status_variable(self, key, value):
        """
        Set a status variable to the same value on all levels of all steps of the controller.

        The value is written directly to the `__dict__` of the status object and hence does not go through its
        `__setattr__`.

        Args:
            key (str): Name of the status variable
            value: The value to set

        Returns:
            None
        """
        for S in self._get_steps():
            for L in S.levels:
                L.status.__dict__[key] = value