Coverage for pySDC/core/convergence_controller.py: 84%
113 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import logging
2from pySDC.helpers.pysdc_helper import FrozenClass
# short helper class to add params as attributes
class Pars(FrozenClass):
    """
    Expose the entries of a parameter dictionary as attributes.

    Every instance carries `control_order` (default 0) and `useMPI` (default None); entries of `params` with the same
    name override these defaults, and any other entries become additional attributes.
    """

    def __init__(self, params):
        # defaults available to every convergence controller
        self.control_order = 0  # integer that determines the order in which the convergence controllers are called
        self.useMPI = None  # depends on the controller

        for name in params:
            setattr(self, name, params[name])

        # NOTE(review): presumably locks the attribute set via FrozenClass -- confirm in pySDC.helpers.pysdc_helper
        self._freeze()
# short helper class to store status variables
class Status(FrozenClass):
    """
    Initialize status variables with None, since at the time of instantiation of the convergence controllers, not all
    relevant information about the controller are known.
    """

    def __init__(self, status_variabes):
        """
        Create one attribute per requested status variable, each initialized to None.

        Args:
            status_variabes (iterable of str): Names of the status variables (parameter name kept as-is so existing
                keyword callers keep working)
        """
        # a plain loop instead of a throw-away list comprehension: we only want the side effect of setattr
        for key in status_variabes:
            setattr(self, key, None)

        # NOTE(review): presumably locks the attribute set via FrozenClass -- confirm in pySDC.helpers.pysdc_helper
        self._freeze()
class ConvergenceController(object):
    """
    Base abstract class for convergence controller, which is plugged into the controller to determine the iteration
    count and time step size.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Raises:
            AssertionError: If `check_parameters` reports incompatible parameters
        """
        self.controller = controller
        # create the logger before calling the overridable hooks below, so that `setup`, `check_parameters` and
        # `dependencies` implementations in child classes can already log
        self.logger = logging.getLogger(f"{type(self).__name__}")
        self.params = Pars(self.setup(controller, params, description))
        # note: `check_parameters` receives the raw `params`, not the merged dictionary returned by `setup`
        params_ok, msg = self.check_parameters(controller, params, description)
        assert params_ok, f'{type(self).__name__} -- {msg}'
        self.dependencies(controller, description)

        if self.params.useMPI:
            self.prepare_MPI_datatypes()

    def prepare_MPI_logical_operations(self):
        """
        Prepare MPI logical operations so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_LAND = MPI.LAND
        self.MPI_LOR = MPI.LOR

    def prepare_MPI_datatypes(self):
        """
        Prepare MPI datatypes so we don't need to import mpi4py all the time
        """
        from mpi4py import MPI

        self.MPI_INT = MPI.INT
        self.MPI_DOUBLE = MPI.DOUBLE
        self.MPI_BOOL = MPI.BOOL

    def log(self, msg, S, level=15, **kwargs):
        """
        Shortcut that has a default level for the logger. 15 is above debug but below info.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step
            level (int): the level passed to the logger

        Returns:
            None
        """
        self.logger.log(level, f'Process {S.status.slot:2d} on time {S.time:.6f} - {msg}')
        return None

    def debug(self, msg, S, **kwargs):
        """
        Shortcut to pass messages at debug level to the logger.

        Args:
            msg (str): Message you want to log
            S (pySDC.step): The current step

        Returns:
            None
        """
        self.log(msg=msg, S=S, level=10, **kwargs)
        return None

    def setup(self, controller, params, description, **kwargs):
        """
        Setup various variables that only need to be set once in the beginning.
        If the convergence controller is added automatically, you can give it params by adding it manually.
        It will be instantiated only once with the manually supplied parameters overriding automatically added
        parameters.

        This function scans the convergence controllers supplied to the description object for instances of itself.
        This corresponds to the convergence controller being added manually by the user. If something is found, this
        function will then return a composite dictionary from the `params` passed to this function as well as the
        `params` passed manually, with priority to manually added parameters. If you added the convergence controller
        manually, that is of course the same and nothing happens. If, on the other hand, the convergence controller was
        added automatically, the `params` passed here will come from whatever added it and you can now override
        parameters by adding the convergence controller manually.
        This relies on children classes to return a composite dictionary from their defaults and from the result of this
        function, so you should write
        ```
        return {**defaults, **super().setup(controller, params, description, **kwargs)}
        ```
        when overloading this method in a child class, with `defaults` a dictionary containing default parameters.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary after setup
        """
        # allow to change parameters by adding the convergence controller manually
        return {**params, **description.get('convergence_controllers', {}).get(type(self), {})}

    def dependencies(self, controller, description, **kwargs):
        """
        Load dependencies on other convergence controllers here.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        pass

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check whether parameters are compatible with whatever assumptions went into the step size functions etc.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        return True, ""

    def check_iteration_status(self, controller, S, **kwargs):
        """
        Determine whether to keep iterating or not in this function.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def get_new_step_size(self, controller, S, **kwargs):
        """
        This function allows to set a step size with arbitrary criteria.
        Make sure to give an order to the convergence controller by setting the `control_order` variable in the params.
        This variable is an integer and you can see what the current order is by using
        `controller.print_convergence_controllers()`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def determine_restart(self, controller, S, **kwargs):
        """
        Determine for each step separately if it wants to be restarted for whatever reason.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def reset_status_variables(self, controller, **kwargs):
        """
        Reset status variables.
        This is called in the `restart_block` function.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def setup_status_variables(self, controller, **kwargs):
        """
        Setup status variables.
        This is not done at the time of instantiation, since the controller is not fully instantiated at that time and
        hence not all information are available. Instead, this function is called after the controller has been fully
        instantiated.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        return None

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Buffers refer to variables used across multiple steps that are stored in the convergence controller classes to
        imitate communication in non MPI versions. These have to be reset in order to replicate availability of
        variables in MPI versions.

        For instance, if step 0 sets self.buffers.x = 1 from self.buffers.x = 0, when the same MPI rank uses the
        variable with step 1, it will still carry the value of self.buffers.x = 1, equivalent to a send from the rank
        computing step 0 to the rank computing step 1.

        However, you can only receive what somebody sent and in order to make sure that is true for the non MPI
        versions, we reset after each iteration so you cannot use this function to communicate backwards from the last
        step to the first one for instance.

        This function is called both at the end of instantiating the controller, as well as after each iteration.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        pass

    def pre_iteration_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to before each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after each iteration here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_step_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after each step here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def post_run_processing(self, controller, S, **kwargs):
        """
        Do whatever you want to after the run here.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """
        pass

    def prepare_next_block(self, controller, S, size, time, Tend, **kwargs):
        """
        Prepare stuff like spreading step sizes or whatever.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (float): The current time (a list of times in the nonMPI controller implementation)
            Tend (float): The final time

        Returns:
            None
        """
        pass

    def convergence_control(self, controller, S, **kwargs):
        """
        Call all the functions related to convergence control.
        This is called in `it_check` in the controller after every iteration just after `post_iteration_processing`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None
        """

        self.get_new_step_size(controller, S, **kwargs)
        self.determine_restart(controller, S, **kwargs)
        self.check_iteration_status(controller, S, **kwargs)

        return None

    def post_spread_processing(self, controller, S, **kwargs):
        """
        This function is called at the end of the `SPREAD` stage in the controller

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
        """
        pass

    def send(self, comm, dest, data, blocking=False, **kwargs):
        """
        Send data to a different rank

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            data: Data to be sent
            blocking (bool): Whether the communication is blocking or not

        Returns:
            The request handle from `comm.isend` for non-blocking sends; for blocking sends, whatever `comm.send`
            returns (None in mpi4py)
        """
        # default tag derived from the control order, unless the caller supplied one
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.send(data, dest=dest, **kwargs)
        else:
            req = comm.isend(data, dest=dest, **kwargs)

        return req

    def recv(self, comm, source, **kwargs):
        """
        Receive some data

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving

        Returns:
            whatever has been received
        """
        # default tag derived from the control order, unless the caller supplied one
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} receives from step {source} with tag {kwargs["tag"]}')

        data = comm.recv(source=source, **kwargs)

        return data

    def Send(self, comm, dest, buffer, blocking=False, **kwargs):
        """
        Send buffer-like data to a different rank

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            dest (int): The target rank
            buffer: Buffer for the data
            blocking (bool): Whether the communication is blocking or not

        Returns:
            The request handle from `comm.Isend` for non-blocking sends; for blocking sends, whatever `comm.Send`
            returns (None in mpi4py)
        """
        # default tag derived from the control order, unless the caller supplied one
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}Sends to step {dest} with tag {kwargs["tag"]}')

        if blocking:
            req = comm.Send(buffer, dest=dest, **kwargs)
        else:
            req = comm.Isend(buffer, dest=dest, **kwargs)

        return req

    def Recv(self, comm, source, buffer, **kwargs):
        """
        Receive buffer-like data into `buffer` (blocking)

        Args:
            comm (mpi4py.MPI.Intracomm): Communicator
            source (int): Where to look for receiving
            buffer: Buffer that the received data is written into

        Returns:
            Whatever `comm.Recv` returns (None in mpi4py); the received data ends up in `buffer`
        """
        # default tag derived from the control order, unless the caller supplied one
        kwargs.setdefault('tag', abs(self.params.control_order))

        # log what's happening for debug purposes
        self.logger.debug(f'Step {comm.rank} Receives from step {source} with tag {kwargs["tag"]}')

        data = comm.Recv(buffer, source=source, **kwargs)

        return data

    def _get_steps_for_status(self):
        """
        Return the steps whose status should be modified: `[controller.S]` for a controller whose class is named
        `controller_MPI`, otherwise the list `controller.MS`.
        """
        if type(self.controller).__name__ == 'controller_MPI':
            return [self.controller.S]
        return self.controller.MS

    def add_status_variable_to_step(self, key, value=None):
        """
        Register a new status variable on the steps and optionally assign it a value on all steps.

        Args:
            key (str): Name of the status variable
            value: If not None, assigned to the variable on all steps
        """
        steps = self._get_steps_for_status()

        # NOTE(review): `add_attr` is only called on the first step's status, presumably because it registers the
        # attribute for the whole status class -- confirm in FrozenClass
        steps[0].status.add_attr(key)

        if value is not None:
            self.set_step_status_variable(key, value)

    def set_step_status_variable(self, key, value):
        """
        Assign a value to a status variable on all steps.

        Args:
            key (str): Name of the status variable
            value: The value to assign
        """
        for S in self._get_steps_for_status():
            # written through `__dict__` directly, presumably to bypass the frozen-attribute check of FrozenClass
            S.status.__dict__[key] = value

    def add_status_variable_to_level(self, key, value=None):
        """
        Register a new status variable on the levels and optionally assign it a value on all levels of all steps.

        Args:
            key (str): Name of the status variable
            value: If not None, assigned to the variable on all levels of all steps
        """
        steps = self._get_steps_for_status()

        # NOTE(review): only the first level of the first step gets `add_attr`, presumably because it registers the
        # attribute for the whole status class -- confirm in FrozenClass
        steps[0].levels[0].status.add_attr(key)

        if value is not None:
            self.set_level_status_variable(key, value)

    def set_level_status_variable(self, key, value):
        """
        Assign a value to a status variable on all levels of all steps.

        Args:
            key (str): Name of the status variable
            value: The value to assign
        """
        for S in self._get_steps_for_status():
            for L in S.levels:
                # written through `__dict__` directly, presumably to bypass the frozen-attribute check of FrozenClass
                L.status.__dict__[key] = value