Coverage for pySDC/implementations/convergence_controller_classes/crash.py: 98%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1from pySDC.core.convergence_controller import ConvergenceController 

2from pySDC.core.errors import ConvergenceError 

3import numpy as np 

4import time 

5 

6 

class CrashBase(ConvergenceController):
    """
    Crash the code across all ranks.

    Subclasses decide locally whether this rank wants to crash and then call `communicate_crash`, which combines
    that decision with a logical OR over the sweeper communicator (`self.comm`) and, if supplied, the controller
    communicator, before raising a `ConvergenceError`.
    """

    def __init__(self, controller, params, description, **kwargs):
        super().__init__(controller, params, description, **kwargs)

        # Subclasses are expected to assign `self.comm` in their `setup`. Fall back to "no sweeper communicator" so a
        # subclass that does not do so fails neither here nor later in `communicate_crash` with an AttributeError.
        if not hasattr(self, 'comm'):
            self.comm = None

        if self.comm or self.params.useMPI:
            from mpi4py import MPI

            self.MPI_OR = MPI.LOR

    def _get_mpi_or(self):
        """
        Return the MPI logical-or reduction operation.

        `__init__` only stores `MPI.LOR` when a sweeper communicator is present or `useMPI` is set. A controller
        communicator may still be handed to `communicate_crash` in other cases, so import mpi4py lazily here.

        Returns:
            mpi4py.MPI.Op: `MPI.LOR`
        """
        if not hasattr(self, 'MPI_OR'):
            from mpi4py import MPI

            self.MPI_OR = MPI.LOR
        return self.MPI_OR

    def communicate_crash(self, crash, msg='', comm=None, **kwargs):
        """
        Communicate a crash across all ranks and raise an error if so.

        Args:
            crash (bool): If this rank wants to crash
            msg (str): Message attached to the raised error
            comm (mpi4py.MPI.Intracomm or None): Communicator of the controller, if applicable
            **kwargs: Ignored; lets callers forward the keyword arguments they received unchanged

        Raises:
            ConvergenceError: If the (reduced) crash flag is set
        """
        # communicate across the sweeper
        if self.comm:
            crash = self.comm.allreduce(crash, op=self._get_mpi_or())

        # communicate across the steps
        if comm:
            crash = comm.allreduce(crash, op=self._get_mpi_or())

        if crash:
            raise ConvergenceError(msg)

38 

39 

class StopAtNan(CrashBase):
    """
    Crash the code when the norm of the solution exceeds some limit or contains nan.
    This class is useful when running with MPI in the sweeper or controller.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Define parameters here.

        Default parameters are:
         - thresh (float): Crash the code when the norm of the solution exceeds this threshold

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        self.comm = description['sweeper_params'].get('comm', None)
        defaults = {
            "control_order": 94,
            "thresh": np.inf,
        }

        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    @staticmethod
    def _iter_solutions(S):
        """
        Yield the solution values stored on all levels of the step, in order.

        Within each level, iteration stops at the first `None` entry and continues with the next level.

        Args:
            S (pySDC.Step.step): Step

        Yields:
            The non-None entries of `lvl.u` for every level `lvl` in `S.levels`
        """
        for lvl in S.levels:
            for u in lvl.u:
                if u is None:
                    break
                yield u

    def _exceeds_bounds(self, u):
        """
        Check whether a single solution value is non-finite or not below the threshold.

        Args:
            u: A solution value from a level of the step

        Returns:
            bool: True if this value should make the code crash
        """
        isfinite = np.all(np.isfinite(u))
        # Both checks are always evaluated, so `abs(u)` is called for every inspected value as before.
        # `abs(u)` may be a scalar norm or an array of magnitudes depending on the data type; `np.all` reduces the
        # comparison to a single boolean either way (a bare array here would raise "truth value is ambiguous").
        below_limit = np.all(abs(u) < self.params.thresh)
        return not (isfinite and below_limit)

    def prepare_next_block(self, controller, S, *args, **kwargs):
        """
        Check if we need to crash the code.

        Args:
            controller (pySDC.Controller.controller): Controller
            S (pySDC.Step.step): Step
            comm (mpi4py.MPI.Intracomm or None): Communicator of the controller, if applicable (passed via kwargs and
                forwarded to `communicate_crash` together with any other keyword arguments)

        Raises:
            ConvergenceError: If the solution does not fall within the allowed space
        """
        # `any` stops at the first offending value, just like the explicit break-out loop did
        crash = any(self._exceeds_bounds(u) for u in self._iter_solutions(S))

        self.communicate_crash(crash, msg=f'Solution exceeds bounds! Crashing code at {S.time}!', **kwargs)

100 

101 

class StopAtMaxRuntime(CrashBase):
    """
    Abort the code once the wall-clock time measured since `setup` exceeds a maximum runtime.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Register the parameters of this convergence controller and start the runtime clock.

        Default parameters are:
         - max_runtime (float): Crash the code when the elapsed runtime in seconds exceeds this threshold

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        self.comm = description['sweeper_params'].get('comm', None)
        defaults = dict(control_order=94, max_runtime=np.inf)

        # reference point for the runtime check, taken when `setup` runs
        self.t0 = time.perf_counter()

        user_params = super().setup(controller, params, description, **kwargs)
        return {**defaults, **user_params}

    def prepare_next_block(self, controller, S, *args, **kwargs):
        """
        Crash the code on all ranks if the elapsed runtime exceeds `max_runtime`.

        Args:
            controller (pySDC.Controller.controller): Controller
            S (pySDC.Step.step): Step
            comm (mpi4py.MPI.Intracomm or None): Communicator of the controller, if applicable (passed via kwargs and
                forwarded to `communicate_crash`)

        Raises:
            ConvergenceError: If the time elapsed since `setup` exceeds `max_runtime` on any participating rank
        """
        elapsed = abs(time.perf_counter() - self.t0)
        too_slow = elapsed > self.params.max_runtime
        self.communicate_crash(
            crash=too_slow,
            msg=f'Exceeding max. runtime of {self.params.max_runtime}s! Crashing code at {S.time}!',
            **kwargs,
        )