Coverage for pySDC/implementations/convergence_controller_classes/basic_restarting.py: 97%
99 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1from pySDC.core.convergence_controller import ConvergenceController, Pars
2from pySDC.implementations.convergence_controller_classes.spread_step_sizes import (
3 SpreadStepSizesBlockwise,
4)
5from pySDC.core.errors import ConvergenceError
6import numpy as np
class BasicRestarting(ConvergenceController):
    """
    Common base for the convergence controllers that handle restarts.

    This class takes care of the parameters, the status variables and the dependencies. The decision whether a step
    is restarted is made in `determine_restart`, which has to be provided by the flavor returned from
    `get_implementation`. The flavors in this module
     - restart every step that comes after a step which requested a restart,
     - give up restarting once a step was restarted `max_restarts` times in a row.

    Default control order is 95.
    """

    @classmethod
    def get_implementation(cls, useMPI):
        """
        Select the flavor of this convergence controller that matches the controller in use.

        Args:
            useMPI (bool): Whether or not the controller uses MPI

        Returns:
            cls: `BasicRestartingMPI` if `useMPI` is truthy, `BasicRestartingNonMPI` otherwise
        """
        return BasicRestartingMPI if useMPI else BasicRestartingNonMPI

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Creates the buffers that carry restart information from step to step.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        super().__init__(controller, params, description)
        self.buffers = Pars(dict(restart=False, max_restart_reached=False))

    def setup(self, controller, params, description, **kwargs):
        """
        Assemble the parameters and register the hook that logs restarts with the controller.

        Default parameters are:
         - control_order (int): The order relative to other convergence controllers (95)
         - max_restarts (int): Maximum number of restarts in a row we allow before we give up on restarting (10)
         - crash_after_max_restarts (bool): Raise an error once the maximum is reached instead of moving on (True)
         - restart_from_first_step (bool): Restart all steps of the block as soon as any step wants a restart (False)
         - step_size_spreader (pySDC.ConvergenceController): A convergence controller that takes care of distributing
           the step sizes between blocks

        Values returned by the parent class' `setup` take precedence over these defaults.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller; must contain 'useMPI'
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        spreader = SpreadStepSizesBlockwise.get_implementation(useMPI=params['useMPI'])
        merged = dict(
            control_order=95,
            max_restarts=10,
            crash_after_max_restarts=True,
            restart_from_first_step=False,
            step_size_spreader=spreader,
        )

        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller.add_hook(LogRestarts)

        # anything coming from the parent class overrides the defaults above
        merged.update(super().setup(controller, params, description, **kwargs))
        return merged

    def setup_status_variables(self, *args, **kwargs):
        """
        Add two status variables to the steps: `restart` (whether to restart now, initially False) and
        `restarts_in_a_row` (how many times the step has been restarted in a row, initially 0).

        Returns:
            None
        """
        for name, initial_value in (('restart', False), ('restarts_in_a_row', 0)):
            self.add_status_variable_to_step(name, initial_value)

    def reset_status_variables(self, *args, **kwargs):
        """
        Set the `restart` status variable of the step back to False. `restarts_in_a_row` is not touched here.

        Returns:
            None
        """
        self.set_step_status_variable('restart', False)

    def dependencies(self, controller, description, **kwargs):
        """
        Load the convergence controller that spreads the step sizes between steps. It is told to spread from the
        first restarted step unless we restart from the first step of the block anyways.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        controller.add_convergence_controller(
            self.params.step_size_spreader,
            description=description,
            params={'spread_from_first_restarted': not self.params.restart_from_first_step},
        )
        return None

    def determine_restart(self, controller, S, **kwargs):
        """
        Placeholder for the restart decision; the MPI and non-MPI flavors have to override this.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Raises:
            NotImplementedError: Always, since the base class makes no decision

        Returns:
            None
        """
        raise NotImplementedError("Please implement a function to determine if we need a restart here!")
class BasicRestartingNonMPI(BasicRestarting):
    """
    Non-MPI specific version of basic restarting. Information that would be communicated between ranks is kept in
    `self.buffers` while the steps are processed one after the other.
    """

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Reset all variables which are used to simulate communication here

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        self.buffers.restart = False
        self.buffers.max_restart_reached = False

        return None

    def determine_restart(self, controller, S, MS, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            MS (list): List of active steps

        Raises:
            ConvergenceError: If the first step wants to be restarted after reaching `max_restarts` restarts in a row
                and `crash_after_max_restarts` is set

        Returns:
            None
        """
        # check if we performed too many restarts; only the first step of the block makes this decision
        if S.status.first:
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    raise ConvergenceError(f"Restarted {S.status.restarts_in_a_row} time(s) already, surrendering now.")
                # adjacent literals are used instead of a backslash continuation inside the string, which would
                # put the indentation of the source code into the log message
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, "
                    "maximum reached, moving on...",
                    S,
                )

        # remember that some step up to and including this one wants a restart ...
        self.buffers.restart = S.status.restart or self.buffers.restart
        # ... and restart this step as well, unless we ran out of patience
        S.status.restart = (S.status.restart or self.buffers.restart) and not self.buffers.max_restart_reached

        # once the last step has been looked at, every step of the block gets the same decision in this mode
        if S.status.last and self.params.restart_from_first_step and not self.buffers.max_restart_reached:
            for step in MS:
                step.status.restart = self.buffers.restart

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, MS, **kwargs):
        """
        Update restarts in a row for all steps.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            MS (list): List of active steps

        Returns:
            None
        """
        if S not in MS:
            return None

        # smallest slot among the steps that are restarted, or the last slot if no step is restarted
        restart_from = min([me.status.slot for me in MS if me.status.restart] + [size - 1])

        if S.status.slot < restart_from:
            # this step comes before the first restarted one: zero the counter of the step at the mirrored index
            MS[restart_from - S.status.slot].status.restarts_in_a_row = 0
        else:
            # hand the counter of this step to the step `restart_from` slots earlier, incremented if this step is
            # restarted and zeroed otherwise
            step = MS[S.status.slot - restart_from]
            step.status.restarts_in_a_row = S.status.restarts_in_a_row + 1 if S.status.restart else 0

        return None
class BasicRestartingMPI(BasicRestarting):
    """
    MPI specific version of basic restarting. Restart information is passed between ranks with the communicator that
    is handed to the individual functions.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Adds a buffer.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        # imported here such that the module can be loaded without mpi4py as long as this flavor is not used
        from mpi4py import MPI

        # logical "or" used as reduction operation in `determine_restart`
        self.OR = MPI.LOR

        super().__init__(controller, params, description)
        self.buffers = Pars({"restart": False, "max_restart_reached": False, 'restart_earlier': False})

    def determine_restart(self, controller, S, comm, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            comm (mpi4py.MPI.Intracomm): Communicator

        Raises:
            ConvergenceError: If the maximum number of restarts in a row was reached, a restart was requested on the
                first step and `crash_after_max_restarts` is set

        Returns:
            None
        """
        crash_now = False

        if S.status.first:
            # check if we performed too many restarts
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts
            self.buffers.restart_earlier = False  # there is no earlier step

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    # raise only after the communication below, such that the flag can be passed on first
                    crash_now = True
                # adjacent literals are used instead of a backslash continuation inside the string, which would
                # put the indentation of the source code into the log message
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, "
                    "maximum reached, moving on...",
                    S,
                )
        elif not S.status.prev_done and not self.params.restart_from_first_step:
            # receive information about restarts from earlier ranks
            buff = np.empty(3, dtype=bool)
            self.Recv(comm=comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
            self.buffers.restart_earlier = buff[0]
            self.buffers.max_restart_reached = buff[1]
            crash_now = buff[2]

        # decide whether to restart
        S.status.restart = (S.status.restart or self.buffers.restart_earlier) and not self.buffers.max_restart_reached

        # send information about restarts forward
        if not S.status.last and not self.params.restart_from_first_step:
            buff = np.empty(3, dtype=bool)
            buff[0] = S.status.restart
            buff[1] = self.buffers.max_restart_reached
            buff[2] = crash_now
            self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

        if self.params.restart_from_first_step:
            # Everybody restarts if anybody wants to, unless rank 0 reports that the maximum has been reached. The
            # comparison is `>=` to agree with the check on the first step above; `>` allowed one restart more than
            # `max_restarts` whenever a later step requested it.
            # NOTE(review): `crash_now` is not distributed in this mode, so only the first step raises below -
            # confirm that the other ranks are supposed to carry on.
            max_restart_reached = comm.bcast(S.status.restarts_in_a_row >= self.params.max_restarts, root=0)
            S.status.restart = comm.allreduce(S.status.restart, op=self.OR) and not max_restart_reached

        if crash_now:
            raise ConvergenceError("Surrendering because of too many restarts...")

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, comm, **kwargs):
        """
        Update restarts in a row for all steps.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None
        """
        # smallest slot among the steps that are restarted, or the last slot if no step is restarted
        restart_from = min(comm.allgather(S.status.slot if S.status.restart else S.status.time_size - 1))

        # send "backward" the number of restarts in a row
        if S.status.slot >= restart_from:
            buff = np.empty(1, dtype=int)
            buff[0] = int(S.status.restarts_in_a_row + 1 if S.status.restart else 0)
            self.Send(
                comm,
                dest=S.status.slot - restart_from,
                buffer=[buff, self.MPI_INT],
                blocking=False,
            )

        # receive new number of restarts in a row
        if S.status.slot + restart_from < size:
            buff = np.empty(1, dtype=int)
            self.Recv(comm, source=(S.status.slot + restart_from), buffer=[buff, self.MPI_INT])
            S.status.restarts_in_a_row = buff[0]
        else:
            # there is no step `restart_from` slots later to receive from, so start counting from zero
            S.status.restarts_in_a_row = 0

        return None