Coverage for pySDC/implementations/convergence_controller_classes/basic_restarting.py: 95%
99 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1from pySDC.core.convergence_controller import ConvergenceController, Pars
2from pySDC.implementations.convergence_controller_classes.spread_step_sizes import (
3 SpreadStepSizesBlockwise,
4)
5from pySDC.core.errors import ConvergenceError
6import numpy as np
class BasicRestarting(ConvergenceController):
    """
    Shared utilities for restarting steps. This convergence controller
     - marks every step that comes after a step requesting a restart for restarting as well, and
     - caps how many times in a row a step may be restarted before the computation moves on regardless
       (or raises, depending on the ``crash_after_max_restarts`` parameter).

    Default control order is 95.
    """

    @classmethod
    def get_implementation(cls, useMPI):
        """
        Select the flavor of this class matching the controller.

        Args:
            useMPI (bool): Whether or not the controller uses MPI

        Returns:
            cls: ``BasicRestartingMPI`` when ``useMPI`` is truthy, ``BasicRestartingNonMPI`` otherwise
        """
        return BasicRestartingMPI if useMPI else BasicRestartingNonMPI

    def __init__(self, controller, params, description, **kwargs):
        """
        Set up the parent class and the buffers used to pass restart information between steps.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        super().__init__(controller, params, description)
        initial_buffers = {"restart": False, "max_restart_reached": False}
        self.buffers = Pars(initial_buffers)

    def setup(self, controller, params, description, **kwargs):
        """
        Define parameters here and attach the ``LogRestarts`` hook to the controller.

        Default parameters are:
         - control_order (int): The order relative to other convergence controllers (95)
         - max_restarts (int): Maximum number of restarts in a row we allow each step before giving up (10)
         - crash_after_max_restarts (bool): Raise a ``ConvergenceError`` once the maximum is reached instead of
           moving on with whatever we have (True)
         - restart_from_first_step (bool): Restart the whole block rather than only the steps from the first
           restarted one onward (False)
         - step_size_spreader (pySDC.ConvergenceController): A convergence controller that takes care of
           distributing the step sizes between blocks (flavor picked according to ``params['useMPI']``)

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The defaults, overridden by whatever the parent ``setup`` returns
        """
        spreader = SpreadStepSizesBlockwise.get_implementation(useMPI=params['useMPI'])
        defaults = {
            "control_order": 95,
            "max_restarts": 10,
            "crash_after_max_restarts": True,
            "restart_from_first_step": False,
            "step_size_spreader": spreader,
        }

        # local import, kept local as in the original layout
        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller.add_hook(LogRestarts)

        user_params = super().setup(controller, params, description, **kwargs)
        return {**defaults, **user_params}

    def setup_status_variables(self, *args, **kwargs):
        """
        Register two status variables on the steps: ``restart`` (whether to restart now, initially False) and
        ``restarts_in_a_row`` (how many consecutive restarts the step has had, initially 0).

        Returns:
            None
        """
        self.add_status_variable_to_step('restart', False)
        self.add_status_variable_to_step('restarts_in_a_row', 0)

    def reset_status_variables(self, *args, **kwargs):
        """
        Clear the ``restart`` status flag. ``restarts_in_a_row`` is deliberately left untouched here.

        Returns:
            None
        """
        self.set_step_status_variable('restart', False)

    def dependencies(self, controller, description, **kwargs):
        """
        Load a convergence controller that spreads the step sizes between steps. It is told to spread from the
        first restarted step unless the whole block is restarted from the first step.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        controller.add_convergence_controller(
            self.params.step_size_spreader,
            description=description,
            params={'spread_from_first_restarted': not self.params.restart_from_first_step},
        )
        return None

    def determine_restart(self, controller, S, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost
        patience with the restarts and want to move on anyways. Must be provided by the flavor classes.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Please implement a function to determine if we need a restart here!")
class BasicRestartingNonMPI(BasicRestarting):
    """
    Non-MPI specific version of basic restarting.

    All steps of a block are handled by this one object, so information between steps is passed through
    ``self.buffers``, which is cleared in ``reset_buffers_nonMPI``.
    """

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Reset all variables which are used to simulate communication here.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        self.buffers.restart = False
        self.buffers.max_restart_reached = False

        return None

    def determine_restart(self, controller, S, MS, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost
        patience with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            MS (list): List of active steps

        Returns:
            None

        Raises:
            ConvergenceError: if a restart is requested after the maximum number of restarts in a row was reached
                and ``crash_after_max_restarts`` is set
        """
        # check if we performed too many restarts; the first step's count is authoritative for the whole block
        if S.status.first:
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts

        if self.buffers.max_restart_reached and S.status.restart:
            if self.params.crash_after_max_restarts:
                raise ConvergenceError(f"Restarted {S.status.restarts_in_a_row} time(s) already, surrendering now.")
            self.log(
                f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, maximum reached, "
                "moving on...",
                S,
            )

        # remember whether any step so far requested a restart, and propagate it to this step
        self.buffers.restart = S.status.restart or self.buffers.restart
        S.status.restart = (S.status.restart or self.buffers.restart) and not self.buffers.max_restart_reached

        # once the last step has been visited, restart the whole block if any step asked for it
        if S.status.last and self.params.restart_from_first_step and not self.buffers.max_restart_reached:
            for step in MS:
                step.status.restart = self.buffers.restart

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, MS, **kwargs):
        """
        Update restarts in a row for all steps.

        Let ``restart_from`` be the slot of the earliest step requesting a restart (``size - 1`` if none does).
        The step in slot ``j`` of the next block takes over from the step in slot ``j + restart_from`` of this
        block. Its counter becomes that step's count plus one if that step restarted, and zero otherwise. Slots
        with no such predecessor (``j + restart_from >= size``) start at zero. This is the same mapping the MPI
        version realizes with its send/receive pattern.

        The counters are read and written in place. Updating them one step per call therefore lets earlier writes
        overwrite values that later steps still need to read. Instead, all counters are updated at once, from a
        snapshot, when this is called for the first active step. Calls for the other steps do nothing.
        Like the indexing below, this assumes ``MS[i]`` is the step in slot ``i``.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            MS (list): List of active steps

        Returns:
            None
        """
        if S not in MS or S is not MS[0]:
            return None

        restart_from = min([me.status.slot for me in MS if me.status.restart] + [size - 1])

        # snapshot the current state before overwriting any counter
        snapshot = [(me.status.restart, me.status.restarts_in_a_row) for me in MS]

        for step in MS:
            source_slot = step.status.slot + restart_from
            if source_slot < size:
                restarted, restarts_in_a_row = snapshot[source_slot]
                step.status.restarts_in_a_row = restarts_in_a_row + 1 if restarted else 0
            else:
                step.status.restarts_in_a_row = 0

        return None
class BasicRestartingMPI(BasicRestarting):
    """
    MPI specific version of basic restarting.

    Each step is handled on its own rank of ``comm``. Restart information is passed forward rank by rank with
    point-to-point messages, or combined with collectives when ``restart_from_first_step`` is set.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Adds a buffer.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        from mpi4py import MPI

        # logical-or reduction, used to combine restart requests across ranks
        self.OR = MPI.LOR

        super().__init__(controller, params, description)
        # `restart_earlier` records whether a step on an earlier rank requested a restart
        self.buffers = Pars({"restart": False, "max_restart_reached": False, 'restart_earlier': False})

    def determine_restart(self, controller, S, comm, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost
        patience with the restarts and want to move on anyways.

        Without ``restart_from_first_step``, each rank receives a three-entry flag message
        ``(restart_earlier, max_restart_reached, crash_now)`` from the previous rank and sends its own version to
        the next rank. With ``restart_from_first_step``, the first rank broadcasts whether the maximum was reached
        and whether to crash, and the restart request is OR-reduced over all ranks.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None

        Raises:
            ConvergenceError: on every rank, if a restart is requested after the maximum number of restarts in a
                row was reached and ``crash_after_max_restarts`` is set
        """
        crash_now = False

        if S.status.first:
            # check if we performed too many restarts
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts
            self.buffers.restart_earlier = False  # there is no earlier step

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    # do not raise yet: the decision has to reach the other ranks first
                    crash_now = True
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, maximum reached, "
                    "moving on...",
                    S,
                )
        elif not S.status.prev_done and not self.params.restart_from_first_step:
            # receive information about restarts from earlier ranks
            buff = np.empty(3, dtype=bool)
            self.Recv(comm=comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
            self.buffers.restart_earlier = buff[0]
            self.buffers.max_restart_reached = buff[1]
            crash_now = buff[2]

        # decide whether to restart
        S.status.restart = (S.status.restart or self.buffers.restart_earlier) and not self.buffers.max_restart_reached

        # send information about restarts forward
        if not S.status.last and not self.params.restart_from_first_step:
            buff = np.empty(3, dtype=bool)
            buff[0] = S.status.restart
            buff[1] = self.buffers.max_restart_reached
            buff[2] = crash_now
            self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

        if self.params.restart_from_first_step:
            # Rank 0 decides for everyone. Use the same `>=` criterion as above (and as the non-MPI version), so the
            # first rank and the collective agree on when the maximum is reached. Also share the crash decision, so
            # that all ranks raise together instead of only rank 0.
            max_restart_reached, crash_now = comm.bcast(
                (S.status.restarts_in_a_row >= self.params.max_restarts, crash_now), root=0
            )
            S.status.restart = comm.allreduce(S.status.restart, op=self.OR) and not max_restart_reached

        if crash_now:
            raise ConvergenceError("Surrendering because of too many restarts...")

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, comm, **kwargs):
        """
        Update restarts in a row for all steps.

        ``restart_from`` is the lowest slot requesting a restart, or ``time_size - 1`` if no slot does. Every rank
        with ``slot >= restart_from`` sends ``restarts_in_a_row + 1`` (if it restarts, else 0) to slot
        ``slot - restart_from``. Every rank with ``slot + restart_from < size`` receives its new count from slot
        ``slot + restart_from``. The remaining ranks reset their count to 0.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None
        """
        restart_from = min(comm.allgather(S.status.slot if S.status.restart else S.status.time_size - 1))

        # send "backward" the number of restarts in a row
        if S.status.slot >= restart_from:
            buff = np.empty(1, dtype=int)
            buff[0] = int(S.status.restarts_in_a_row + 1 if S.status.restart else 0)
            # NOTE(review): sent with blocking=False; completion of the request is left to `self.Send` — confirm
            self.Send(
                comm,
                dest=S.status.slot - restart_from,
                buffer=[buff, self.MPI_INT],
                blocking=False,
            )

        # receive new number of restarts in a row
        if S.status.slot + restart_from < size:
            buff = np.empty(1, dtype=int)
            self.Recv(comm, source=(S.status.slot + restart_from), buffer=[buff, self.MPI_INT])
            S.status.restarts_in_a_row = buff[0]
        else:
            S.status.restarts_in_a_row = 0

        return None