Coverage for pySDC/implementations/convergence_controller_classes/crash.py: 98%
44 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
1from pySDC.core.convergence_controller import ConvergenceController
2from pySDC.core.errors import ConvergenceError
3import numpy as np
4import time
class CrashBase(ConvergenceController):
    """
    Base class for convergence controllers that abort the run on all ranks together.

    A subclass decides locally whether it wants to crash and passes that decision to
    `communicate_crash`, which combines it across the available communicators before
    raising.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialise the parent class and store the MPI reduction operation if needed.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller
        """
        super().__init__(controller, params, description, **kwargs)

        # NOTE(review): `self.comm` is assigned in the subclasses' `setup`; presumably the
        # parent constructor has called `setup` by now - confirm against ConvergenceController.
        if self.comm or self.params.useMPI:
            # import lazily so that mpi4py is only required when MPI is actually used
            from mpi4py import MPI

            self.MPI_OR = MPI.LOR

    def communicate_crash(self, crash, msg='', comm=None, **kwargs):
        """
        Combine the crash request of all ranks and raise an error if any rank asked for it.

        Args:
            crash (bool): If this rank wants to crash
            msg (str): Message attached to the error that is raised
            comm (mpi4py.MPI.Intracomm or None): Communicator of the controller, if applicable

        Raises:
            ConvergenceError: If the (combined) crash flag is set
        """
        # reduce across the sweeper communicator first, then across the steps
        for communicator in (self.comm, comm):
            if communicator:
                crash = communicator.allreduce(crash, op=self.MPI_OR)

        if crash:
            raise ConvergenceError(msg)
class StopAtNan(CrashBase):
    """
    Abort the run as soon as the solution contains non-finite values or its norm
    reaches a configurable threshold.

    The decision is communicated to all ranks, which makes this class useful when
    running with MPI in the sweeper or controller.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Define parameters here.

        Default parameters are:
         - control_order (int): 94
         - thresh (float): Crash the code when the norm of the solution is not below this threshold

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        sweeper_params = description['sweeper_params']
        self.comm = sweeper_params.get('comm', None)

        # values returned by the parent class take precedence over the defaults
        merged = {"control_order": 94, "thresh": np.inf}
        merged.update(super().setup(controller, params, description, **kwargs))
        return merged

    def prepare_next_block(self, controller, S, *args, **kwargs):
        """
        Check if we need to crash the code.

        Args:
            controller (pySDC.Controller.controller): Controller
            S (pySDC.Step.step): Step
            comm (mpi4py.MPI.Intracomm or None): Communicator of the controller, if applicable

        Raises:
            ConvergenceError: If the solution does not fall within the allowed space
        """

        def out_of_bounds():
            # Yield, for every stored solution, whether it violates the bounds.
            # Scanning a level stops at its first entry that is None.
            for level in S.levels:
                for sol in level.u:
                    if sol is None:
                        break
                    finite = np.all(np.isfinite(sol))
                    bounded = abs(sol) < self.params.thresh
                    yield not (finite and bounded)

        # `any` stops at the first violation, so later solutions are not inspected
        crash = any(out_of_bounds())

        self.communicate_crash(crash, msg=f'Solution exceeds bounds! Crashing code at {S.time}!', **kwargs)
class StopAtMaxRuntime(CrashBase):
    """
    Abort the code once the time elapsed since `setup` exceeds a maximum runtime.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Define parameters here and start the clock.

        Default parameters are:
         - control_order (int): 94
         - max_runtime (float): Crash the code when the elapsed runtime in seconds exceeds this threshold

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        self.comm = description['sweeper_params'].get('comm', None)

        # reference point for the runtime measurement, taken before the parent setup runs
        self.t0 = time.perf_counter()

        parent_params = super().setup(controller, params, description, **kwargs)
        # values returned by the parent class take precedence over the defaults
        return {"control_order": 94, "max_runtime": np.inf, **parent_params}

    def prepare_next_block(self, controller, S, *args, **kwargs):
        """
        Check if the maximum runtime has been exceeded and crash the code if so.

        Args:
            controller (pySDC.Controller.controller): Controller
            S (pySDC.Step.step): Step
            comm (mpi4py.MPI.Intracomm or None): Communicator of the controller, if applicable

        Raises:
            ConvergenceError: If the time elapsed since `setup` exceeds `max_runtime`
        """
        elapsed = abs(self.t0 - time.perf_counter())
        too_long = elapsed > self.params.max_runtime

        self.communicate_crash(
            crash=too_long,
            msg=f'Exceeding max. runtime of {self.params.max_runtime}s! Crashing code at {S.time}!',
            **kwargs,
        )