Coverage for pySDC/implementations/convergence_controller_classes/basic_restarting.py: 97%
101 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1from pySDC.core.ConvergenceController import ConvergenceController, Pars
2from pySDC.implementations.convergence_controller_classes.spread_step_sizes import (
3 SpreadStepSizesBlockwise,
4)
5from pySDC.core.Errors import ConvergenceError
6import numpy as np
class BasicRestarting(ConvergenceController):
    """
    Class with some utilities for restarting. The specific functions are:
     - Telling each step after one that requested a restart to get restarted as well
     - Allowing each step to be restarted a limited number of times in a row before just moving on anyways

    Default control order is 95.
    """

    @classmethod
    def get_implementation(cls, useMPI):
        """
        Retrieve the implementation for a specific flavor of this class.

        Args:
            useMPI (bool): Whether or not the controller uses MPI

        Returns:
            cls: The child class that implements the desired flavor
        """
        return BasicRestartingMPI if useMPI else BasicRestartingNonMPI

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        super().__init__(controller, params, description)
        # buffers used by the child classes to hand restart information from one step to the next
        self.buffers = Pars({"restart": False, "max_restart_reached": False})

    def setup(self, controller, params, description, **kwargs):
        """
        Define parameters here.

        Default parameters are:
         - control_order (int): The order relative to other convergence controllers
         - max_restarts (int): Maximum number of restarts we allow each step before we just move on with whatever we
                               have
         - crash_after_max_restarts (bool): Raise a `ConvergenceError` instead of moving on once the maximum number
                                            of restarts is reached
         - restart_from_first_step (bool): If any step requests a restart, restart the entire block rather than
                                           only the steps starting from the first one that requested it
         - step_size_spreader (pySDC.ConvergenceController): A convergence controller that takes care of distributing
                                                             the steps sizes between blocks

        Also registers the `LogRestarts` hook with the controller.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller. Must contain `useMPI`.
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        defaults = {
            "control_order": 95,
            "max_restarts": 10,
            "crash_after_max_restarts": True,
            "restart_from_first_step": False,
            "step_size_spreader": SpreadStepSizesBlockwise.get_implementation(useMPI=params['useMPI']),
        }

        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller.add_hook(LogRestarts)

        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    @staticmethod
    def _get_status_location(**kwargs):
        """
        Determine where the status variables live, as passed to `add_variable` / `reset_variable`.

        Args:
            kwargs: The keyword arguments passed by the controller. A `comm` entry marks the MPI flavor.

        Returns:
            list: `["S", "status"]` if a communicator was passed, `["MS", "status"]` otherwise
        """
        return ["S" if 'comm' in kwargs else "MS", "status"]

    def setup_status_variables(self, controller, **kwargs):
        """
        Add status variables for whether to restart now and how many times the step has been restarted in a row to the
        Steps

        Args:
            controller (pySDC.Controller): The controller
            kwargs: Passed by the controller; the presence of `comm` selects the MPI location of the variables

        Returns:
            None
        """
        where = self._get_status_location(**kwargs)
        self.add_variable(controller, name='restart', where=where, init=False)
        self.add_variable(controller, name='restarts_in_a_row', where=where, init=0)

    def reset_status_variables(self, controller, reset=False, **kwargs):
        """
        Reset the status variable that says whether to restart now. The number of restarts in a row is not touched
        here.

        Args:
            controller (pySDC.Controller): The controller
            reset (bool): Unused, kept for interface compatibility
            kwargs: Passed by the controller; the presence of `comm` selects the MPI location of the variables

        Returns:
            None
        """
        where = self._get_status_location(**kwargs)
        self.reset_variable(controller, name='restart', where=where, init=False)

    def dependencies(self, controller, description, **kwargs):
        """
        Load a convergence controller that spreads the step sizes between steps.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        spread_step_sizes_params = {
            # when the whole block is restarted, the step size is not taken from the first restarted step
            'spread_from_first_restarted': not self.params.restart_from_first_step,
        }
        controller.add_convergence_controller(
            self.params.step_size_spreader, description=description, params=spread_step_sizes_params
        )
        return None

    def determine_restart(self, controller, S, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None

        Raises:
            NotImplementedError: Always; the MPI / non-MPI child classes implement this
        """
        raise NotImplementedError("Please implement a function to determine if we need a restart here!")
class BasicRestartingNonMPI(BasicRestarting):
    """
    Non-MPI specific version of basic restarting
    """

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Reset all variables with are used to simulate communication here

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        self.buffers.restart = False
        self.buffers.max_restart_reached = False

        return None

    def determine_restart(self, controller, S, MS, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            MS (list): List of active steps

        Returns:
            None

        Raises:
            ConvergenceError: If the first step wants to restart, the maximum number of restarts is reached and
                              `crash_after_max_restarts` is set
        """
        # check if we performed too many restarts
        if S.status.first:
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    raise ConvergenceError(f"Restarted {S.status.restarts_in_a_row} time(s) already, surrendering now.")
                # implicit concatenation keeps indentation whitespace out of the message
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, maximum reached, "
                    "moving on...",
                    S,
                )

        # remember any restart request seen so far; presumably steps are processed in slot order, such that a
        # request is handed on to all later steps
        self.buffers.restart = S.status.restart or self.buffers.restart
        S.status.restart = (S.status.restart or self.buffers.restart) and not self.buffers.max_restart_reached

        # at the last step every request is known: restart the whole block if any step asked for it
        if S.status.last and self.params.restart_from_first_step and not self.buffers.max_restart_reached:
            for step in MS:
                step.status.restart = self.buffers.restart

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, MS, **kwargs):
        """
        Update restarts in a row for all steps.

        The step in slot `i` receives the counter of the step in slot `i + restart_from` (incremented if that step
        restarted, 0 otherwise), where `restart_from` is the slot of the first restarting step. Slots with no such
        partner get 0. This mirrors the MPI implementation.

        All counters are read before any of them is written. That is why the whole block is handled in the call for
        the first active step, and the calls for the other steps do nothing. Updating one step per call in place would
        overwrite counters before they are read.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            MS (list): List of active steps

        Returns:
            None
        """
        if not MS or S is not MS[0]:
            return None

        restart_from = min([me.status.slot for me in MS if me.status.restart] + [size - 1])

        # snapshot of what each slot hands on, taken before any counter is modified
        handed_on = {me.status.slot: (me.status.restarts_in_a_row + 1 if me.status.restart else 0) for me in MS}

        for me in MS:
            me.status.restarts_in_a_row = handed_on.get(me.status.slot + restart_from, 0)

        return None
class BasicRestartingMPI(BasicRestarting):
    """
    MPI specific version of basic restarting
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Adds a buffer.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        from mpi4py import MPI

        self.OR = MPI.LOR

        super().__init__(controller, params, description)
        self.buffers = Pars({"restart": False, "max_restart_reached": False, 'restart_earlier': False})

    def determine_restart(self, controller, S, comm, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None

        Raises:
            ConvergenceError: On every rank, if the first step wants to restart, the maximum number of restarts is
                              reached and `crash_after_max_restarts` is set
        """
        crash_now = False

        if S.status.first:
            # check if we performed too many restarts
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts
            self.buffers.restart_earlier = False  # there is no earlier step

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    crash_now = True
                # implicit concatenation keeps indentation whitespace out of the message
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, maximum reached, "
                    "moving on...",
                    S,
                )
        elif not S.status.prev_done and not self.params.restart_from_first_step:
            # receive information about restarts from earlier ranks
            recv_buff = np.empty(3, dtype=bool)
            self.Recv(comm=comm, source=S.status.slot - 1, buffer=[recv_buff, self.MPI_BOOL])
            self.buffers.restart_earlier = recv_buff[0]
            self.buffers.max_restart_reached = recv_buff[1]
            crash_now = recv_buff[2]

        # decide whether to restart
        S.status.restart = (S.status.restart or self.buffers.restart_earlier) and not self.buffers.max_restart_reached

        # send information about restarts forward
        if not S.status.last and not self.params.restart_from_first_step:
            send_buff = np.empty(3, dtype=bool)
            send_buff[0] = S.status.restart
            send_buff[1] = self.buffers.max_restart_reached
            send_buff[2] = crash_now
            self.Send(comm, dest=S.status.slot + 1, buffer=[send_buff, self.MPI_BOOL])

        if self.params.restart_from_first_step:
            # Every rank adopts the decision of rank 0 (assumed to hold the first step). Use the same `>=`
            # criterion as the first step's own check, and share its crash decision. The receive branch above is
            # skipped in this mode, so without this only rank 0 would raise.
            max_restart_reached, crash_now = comm.bcast(
                (S.status.restarts_in_a_row >= self.params.max_restarts, crash_now), root=0
            )
            S.status.restart = comm.allreduce(S.status.restart, op=self.OR) and not max_restart_reached

        if crash_now:
            raise ConvergenceError("Surrendering because of too many restarts...")

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, comm, **kwargs):
        """
        Update restarts in a row for all steps.

        The rank in slot `i` receives the counter of the rank in slot `i + restart_from` (incremented if that step
        restarted, 0 otherwise), where `restart_from` is the slot of the first restarting step. Slots with no such
        partner get 0.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None
        """
        restart_from = min(comm.allgather(S.status.slot if S.status.restart else S.status.time_size - 1))

        # send "backward" the number of restarts in a row
        if S.status.slot >= restart_from:
            send_buff = np.empty(1, dtype=int)
            send_buff[0] = int(S.status.restarts_in_a_row + 1 if S.status.restart else 0)
            self.Send(
                comm,
                dest=S.status.slot - restart_from,
                buffer=[send_buff, self.MPI_INT],
                blocking=False,
            )

        # receive new number of restarts in a row
        if S.status.slot + restart_from < size:
            recv_buff = np.empty(1, dtype=int)
            self.Recv(comm, source=(S.status.slot + restart_from), buffer=[recv_buff, self.MPI_INT])
            S.status.restarts_in_a_row = recv_buff[0]
        else:
            S.status.restarts_in_a_row = 0

        return None