Coverage for pySDC/implementations/convergence_controller_classes/basic_restarting.py: 97%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1from pySDC.core.convergence_controller import ConvergenceController, Pars 

2from pySDC.implementations.convergence_controller_classes.spread_step_sizes import ( 

3 SpreadStepSizesBlockwise, 

4) 

5from pySDC.core.errors import ConvergenceError 

6import numpy as np 

7 

8 

class BasicRestarting(ConvergenceController):
    """
    Common restart utilities shared by the MPI and non-MPI flavors.

    The flavors built on top of this class take care of two things:
     - once a step asks for a restart, the steps after it are restarted too
     - a step may only be restarted a limited number of times in a row; after that we either give up or carry on
       with whatever solution we have

    Default control order is 95.
    """

    @classmethod
    def get_implementation(cls, useMPI):
        """
        Select the flavor of this convergence controller matching the kind of controller in use.

        Args:
            useMPI (bool): Whether or not the controller uses MPI

        Returns:
            cls: ``BasicRestartingMPI`` if ``useMPI`` is truthy, ``BasicRestartingNonMPI`` otherwise
        """
        return BasicRestartingMPI if useMPI else BasicRestartingNonMPI

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Sets up the buffers used to pass restart information between steps.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        super().__init__(controller, params, description)
        initial_buffers = {"restart": False, "max_restart_reached": False}
        self.buffers = Pars(initial_buffers)

    def setup(self, controller, params, description, **kwargs):
        """
        Assemble the parameters of this convergence controller and register the hook that logs restarts.

        Default parameters are:
         - control_order (int): The order relative to other convergence controllers
         - max_restarts (int): Maximum number of restarts we allow each step before we stop restarting
         - crash_after_max_restarts (bool): Raise an error once the maximum is reached instead of moving on
         - restart_from_first_step (bool): Its negation is handed to the step size spreader as
           ``spread_from_first_restarted`` (see ``dependencies``)
         - step_size_spreader (pySDC.ConvergenceController): A convergence controller that takes care of distributing
           the steps sizes between blocks

        Values returned by the parent class' ``setup`` take precedence over these defaults.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller; must contain ``'useMPI'``
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        spreader = SpreadStepSizesBlockwise.get_implementation(useMPI=params['useMPI'])
        fallback = dict(
            control_order=95,
            max_restarts=10,
            crash_after_max_restarts=True,
            restart_from_first_step=False,
            step_size_spreader=spreader,
        )

        # imported locally, as in the original implementation
        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller.add_hook(LogRestarts)

        overrides = super().setup(controller, params, description, **kwargs)
        return {**fallback, **overrides}

    def setup_status_variables(self, *args, **kwargs):
        """
        Register the status variables this controller needs on the steps: ``restart`` (whether to restart now,
        initially ``False``) and ``restarts_in_a_row`` (how often the step was restarted in a row, initially 0).

        Returns:
            None
        """
        for name, initial in (('restart', False), ('restarts_in_a_row', 0)):
            self.add_status_variable_to_step(name, initial)

    def reset_status_variables(self, *args, **kwargs):
        """
        Reset the ``restart`` status variable of the step to ``False``. The ``restarts_in_a_row`` counter is left
        untouched here.

        Returns:
            None
        """
        self.set_step_status_variable('restart', False)

    def dependencies(self, controller, description, **kwargs):
        """
        Load the convergence controller that spreads the step sizes between steps.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        spread_from_first_restarted = not self.params.restart_from_first_step
        controller.add_convergence_controller(
            self.params.step_size_spreader,
            description=description,
            params={'spread_from_first_restarted': spread_from_first_restarted},
        )

    def determine_restart(self, controller, S, **kwargs):
        """
        Decide whether the step needs to be restarted. This is a placeholder that always raises; the flavors
        derived from this class provide the actual implementation.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None

        Raises:
            NotImplementedError: Always
        """
        raise NotImplementedError("Please implement a function to determine if we need a restart here!")

132 

133 

class BasicRestartingNonMPI(BasicRestarting):
    """
    Non-MPI specific version of basic restarting.

    Information that the MPI flavor would communicate between ranks is kept in ``self.buffers`` instead.
    """

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Reset all variables which are used to simulate communication here.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        self.buffers.restart = False
        self.buffers.max_restart_reached = False

        return None

    def determine_restart(self, controller, S, MS, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            MS (list): List of active steps

        Returns:
            None

        Raises:
            ConvergenceError: If the first step wants yet another restart after ``max_restarts`` restarts in a row
                and ``crash_after_max_restarts`` is set
        """
        # check if we performed too many restarts; only the first step of the block decides this
        if S.status.first:
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    raise ConvergenceError(f"Restarted {S.status.restarts_in_a_row} time(s) already, surrendering now.")
                # Adjacent literals instead of backslash continuations inside the string: the latter put the
                # indentation of the source code into the logged message.
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, "
                    "maximum reached, moving on...",
                    S,
                )

        # remember restart requests in the buffer so that they are passed on to the steps handled afterwards
        self.buffers.restart = S.status.restart or self.buffers.restart
        S.status.restart = (S.status.restart or self.buffers.restart) and not self.buffers.max_restart_reached

        # when restarting from the first step, every step in MS gets the accumulated decision once the last step
        # has been handled
        if S.status.last and self.params.restart_from_first_step and not self.buffers.max_restart_reached:
            for step in MS:
                step.status.restart = self.buffers.restart

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, MS, **kwargs):
        """
        Update restarts in a row for all steps.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            MS (list): List of active steps

        Returns:
            None
        """
        if S not in MS:
            return None

        # lowest slot among the steps flagged for restart; falls back to `size - 1` if no step is flagged
        restart_from = min([me.status.slot for me in MS if me.status.restart] + [size - 1])

        # NOTE(review): the counters are updated in place while this is called once per step, so the outcome
        # presumably depends on the order in which the controller visits the steps -- confirm with the caller.
        if S.status.slot < restart_from:
            MS[restart_from - S.status.slot].status.restarts_in_a_row = 0
        else:
            # the step `restart_from` entries earlier in MS takes over the counter of this step
            step = MS[S.status.slot - restart_from]
            step.status.restarts_in_a_row = (S.status.restarts_in_a_row + 1) if S.status.restart else 0

        return None

217 

218 

class BasicRestartingMPI(BasicRestarting):
    """
    MPI specific version of basic restarting.

    Restart information is passed between ranks with the ``Send`` / ``Recv`` helpers of the convergence controller
    and with collective operations on the communicator.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Adds a buffer.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        # imported here such that mpi4py is only required when this flavor is instantiated
        from mpi4py import MPI

        # logical "or" reduction used to combine the restart flags of all ranks
        self.OR = MPI.LOR

        super().__init__(controller, params, description)
        self.buffers = Pars({"restart": False, "max_restart_reached": False, 'restart_earlier': False})

    def determine_restart(self, controller, S, comm, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Unless ``restart_from_first_step`` is set, three booleans are passed from each step to the next one:
        whether the sending step restarts, whether the maximum number of restarts was reached and whether to crash.
        With ``restart_from_first_step``, the decision is instead made collectively for all ranks.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None

        Raises:
            ConvergenceError: If the maximum number of restarts in a row was reached, yet another restart was
                requested by the first step and ``crash_after_max_restarts`` is set
        """
        crash_now = False

        if S.status.first:
            # check if we performed too many restarts
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts
            self.buffers.restart_earlier = False  # there is no earlier step

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    crash_now = True
                else:
                    # Only announce that we move on if we actually do. Adjacent literals are used instead of
                    # backslash continuations inside the string, which put the indentation of the source code
                    # into the logged message.
                    self.log(
                        f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, "
                        "maximum reached, moving on...",
                        S,
                    )
        elif not S.status.prev_done and not self.params.restart_from_first_step:
            # receive information about restarts from earlier ranks
            buff = np.empty(3, dtype=bool)
            self.Recv(comm=comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
            self.buffers.restart_earlier = buff[0]
            self.buffers.max_restart_reached = buff[1]
            crash_now = buff[2]

        # decide whether to restart
        S.status.restart = (S.status.restart or self.buffers.restart_earlier) and not self.buffers.max_restart_reached

        # send information about restarts forward
        if not S.status.last and not self.params.restart_from_first_step:
            buff = np.empty(3, dtype=bool)
            buff[0] = S.status.restart
            buff[1] = self.buffers.max_restart_reached
            buff[2] = crash_now
            self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

        if self.params.restart_from_first_step:
            # Rank 0 decides for everybody whether the maximum was reached. Use `>=` like the check above and
            # like the non-MPI flavor; with `>`, one restart too many was possible when the request came from a
            # rank other than the first one.
            max_restart_reached = comm.bcast(S.status.restarts_in_a_row >= self.params.max_restarts, root=0)
            S.status.restart = comm.allreduce(S.status.restart, op=self.OR) and not max_restart_reached
            # NOTE(review): `crash_now` is not communicated in this mode, so presumably only the first rank
            # raises below -- confirm that this is intended.

        if crash_now:
            raise ConvergenceError("Surrendering because of too many restarts...")

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, comm, **kwargs):
        """
        Update restarts in a row for all steps.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None
        """

        # lowest slot among the steps flagged for restart; steps that are not flagged contribute `time_size - 1`
        restart_from = min(comm.allgather(S.status.slot if S.status.restart else S.status.time_size - 1))

        # send "backward" the number of restarts in a row
        if S.status.slot >= restart_from:
            buff = np.empty(1, dtype=int)
            buff[0] = int((S.status.restarts_in_a_row + 1) if S.status.restart else 0)
            # non-blocking send, the matching receive is posted below by the destination rank
            self.Send(
                comm,
                dest=S.status.slot - restart_from,
                buffer=[buff, self.MPI_INT],
                blocking=False,
            )

        # receive new number of restarts in a row
        if S.status.slot + restart_from < size:
            buff = np.empty(1, dtype=int)
            self.Recv(comm, source=(S.status.slot + restart_from), buffer=[buff, self.MPI_INT])
            S.status.restarts_in_a_row = buff[0]
        else:
            # no step sends to us, so we start counting from scratch
            S.status.restarts_in_a_row = 0

        return None