Coverage for pySDC/implementations/convergence_controller_classes/basic_restarting.py: 97%

101 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1from pySDC.core.ConvergenceController import ConvergenceController, Pars 

2from pySDC.implementations.convergence_controller_classes.spread_step_sizes import ( 

3 SpreadStepSizesBlockwise, 

4) 

5from pySDC.core.Errors import ConvergenceError 

6import numpy as np 

7 

8 

class BasicRestarting(ConvergenceController):
    """
    Shared restarting utilities for convergence controllers. The derived flavors take care of:
     - propagating a restart request to every step following the one that asked for it
     - limiting how often a step may be restarted in a row before the run either moves on or aborts

    Default control order is 95.
    """

    @classmethod
    def get_implementation(cls, useMPI):
        """
        Select the flavor of this class that matches the controller in use.

        Args:
            useMPI (bool): Whether or not the controller uses MPI

        Returns:
            cls: `BasicRestartingMPI` if `useMPI` is truthy, `BasicRestartingNonMPI` otherwise
        """
        return BasicRestartingMPI if useMPI else BasicRestartingNonMPI

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Creates the buffers holding the restart flags after the parent class was set up.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        super().__init__(controller, params, description)
        initial_flags = {"restart": False, "max_restart_reached": False}
        self.buffers = Pars(initial_flags)

    def setup(self, controller, params, description, **kwargs):
        """
        Define parameters here and register the `LogRestarts` hook with the controller.

        Default parameters are:
         - control_order (int): The order relative to other convergence controllers (95)
         - max_restarts (int): Maximum number of restarts we allow each step before we just move on with whatever we
                               have (10)
         - crash_after_max_restarts (bool): Raise a `ConvergenceError` instead of moving on once the maximum number
                                            of restarts has been reached (True)
         - restart_from_first_step (bool): Restart all steps of the block rather than only the steps from the first
                                           one requesting a restart (False)
         - step_size_spreader (pySDC.ConvergenceController): A convergence controller that takes care of distributing
                                                             the steps sizes between blocks

        Values supplied by the parent class' `setup` take precedence over these defaults.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        # look up the spreader first, exactly like the original dictionary literal did
        spreader = SpreadStepSizesBlockwise.get_implementation(useMPI=params['useMPI'])
        defaults = {
            "control_order": 95,
            "max_restarts": 10,
            "crash_after_max_restarts": True,
            "restart_from_first_step": False,
            "step_size_spreader": spreader,
        }

        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller.add_hook(LogRestarts)

        parent_params = super().setup(controller, params, description, **kwargs)
        return {**defaults, **parent_params}

    def setup_status_variables(self, controller, **kwargs):
        """
        Add the status variables `restart` (restart now?) and `restarts_in_a_row` (number of consecutive restarts)
        to the steps.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        # a communicator among the keyword arguments selects "S", otherwise "MS" is used
        target = "S" if 'comm' in kwargs else "MS"
        where = [target, "status"]
        for name, initial_value in (('restart', False), ('restarts_in_a_row', 0)):
            self.add_variable(controller, name=name, where=where, init=initial_value)

    def reset_status_variables(self, controller, reset=False, **kwargs):
        """
        Reset the `restart` status variable to False. The counter `restarts_in_a_row` is left untouched.

        Args:
            controller (pySDC.Controller): The controller
            reset (bool): Whether the function is called for the first time or to reset (not used here)

        Returns:
            None
        """
        target = "S" if 'comm' in kwargs else "MS"
        self.reset_variable(controller, name='restart', where=[target, "status"], init=False)

    def dependencies(self, controller, description, **kwargs):
        """
        Load a convergence controller that spreads the step sizes between steps.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        # spreading from the first restarted step is the opposite of restarting from the first step
        spreader_params = {'spread_from_first_restarted': not self.params.restart_from_first_step}
        controller.add_convergence_controller(
            self.params.step_size_spreader,
            description=description,
            params=spreader_params,
        )
        return None

    def determine_restart(self, controller, S, **kwargs):
        """
        Placeholder for deciding about restarts; the MPI and non-MPI flavors provide the actual implementation.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Returns:
            None

        Raises:
            NotImplementedError: Always, since this method has to be overridden
        """
        raise NotImplementedError("Please implement a function to determine if we need a restart here!")

142 

143 

class BasicRestartingNonMPI(BasicRestarting):
    """
    Non-MPI specific version of basic restarting
    """

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Reset all variables which are used to simulate communication here

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        self.buffers.restart = False
        self.buffers.max_restart_reached = False

        return None

    def determine_restart(self, controller, S, MS, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            MS (list): List of active steps

        Returns:
            None

        Raises:
            ConvergenceError: If the first step requests a restart after the maximum number of restarts has been
                              reached and `crash_after_max_restarts` is set
        """
        # check if we performed too many restarts
        if S.status.first:
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    raise ConvergenceError(f"Restarted {S.status.restarts_in_a_row} time(s) already, surrendering now.")
                # implicit concatenation keeps source indentation out of the logged message
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} "
                    "time(s) already, maximum reached, moving on...",
                    S,
                )

        # the buffer remembers a restart requested by any step handled so far
        self.buffers.restart = S.status.restart or self.buffers.restart
        S.status.restart = (S.status.restart or self.buffers.restart) and not self.buffers.max_restart_reached

        # once the last step is reached, hand the collected decision to every step in the block
        if S.status.last and self.params.restart_from_first_step and not self.buffers.max_restart_reached:
            for step in MS:
                step.status.restart = self.buffers.restart

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, MS, **kwargs):
        """
        Update restarts in a row for all steps.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            MS (list): List of active steps

        Returns:
            None
        """
        if S not in MS:
            return None

        # slot of the first step that restarts; falls back to the last slot if no step restarts
        restart_from = min([me.status.slot for me in MS if me.status.restart] + [size - 1])

        if S.status.slot < restart_from:
            # NOTE(review): the MPI flavor zeroes the counter on ranks with slot + restart_from >= size, whereas
            # this resets index restart_from - slot -- confirm both variants are meant to agree
            MS[restart_from - S.status.slot].status.restarts_in_a_row = 0
        else:
            # the counter moves "backward" by restart_from slots, incremented only if this step restarts
            step = MS[S.status.slot - restart_from]
            step.status.restarts_in_a_row = S.status.restarts_in_a_row + 1 if S.status.restart else 0

        return None

227 

228 

class BasicRestartingMPI(BasicRestarting):
    """
    MPI specific version of basic restarting
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Stores the MPI logical-or operation and adds a `restart_earlier` buffer.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        from mpi4py import MPI

        self.OR = MPI.LOR

        super().__init__(controller, params, description)
        self.buffers = Pars({"restart": False, "max_restart_reached": False, 'restart_earlier': False})

    def determine_restart(self, controller, S, comm, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None

        Raises:
            ConvergenceError: If the maximum number of restarts has been reached and `crash_after_max_restarts` is
                              set. The error is raised only after the communication in this function has finished.
        """
        crash_now = False

        if S.status.first:
            # check if we performed too many restarts
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts
            self.buffers.restart_earlier = False  # there is no earlier step

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    crash_now = True
                # implicit concatenation keeps source indentation out of the logged message;
                # the message is logged even if `crash_now` has just been set
                self.log(
                    f"Step(s) restarted {S.status.restarts_in_a_row} "
                    "time(s) already, maximum reached, moving on...",
                    S,
                )
        elif not S.status.prev_done and not self.params.restart_from_first_step:
            # receive information about restarts from earlier ranks
            buff = np.empty(3, dtype=bool)
            self.Recv(comm=comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
            self.buffers.restart_earlier = buff[0]
            self.buffers.max_restart_reached = buff[1]
            crash_now = buff[2]

        # decide whether to restart
        S.status.restart = (S.status.restart or self.buffers.restart_earlier) and not self.buffers.max_restart_reached

        # send information about restarts forward
        if not S.status.last and not self.params.restart_from_first_step:
            buff = np.empty(3, dtype=bool)
            buff[0] = S.status.restart
            buff[1] = self.buffers.max_restart_reached
            buff[2] = crash_now
            self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

        if self.params.restart_from_first_step:
            # NOTE(review): this uses `>` while the checks above and in the non-MPI flavor use `>=` -- confirm
            # whether the extra restart allowed here is intended
            max_restart_reached = comm.bcast(S.status.restarts_in_a_row > self.params.max_restarts, root=0)
            S.status.restart = comm.allreduce(S.status.restart, op=self.OR) and not max_restart_reached

        # raise only now, after the crash flag has been passed on to the next rank above
        if crash_now:
            raise ConvergenceError("Surrendering because of too many restarts...")

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, comm, **kwargs):
        """
        Update restarts in a row for all steps.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None
        """

        # smallest slot among the restarting steps; ranks that do not restart contribute the last slot
        restart_from = min(comm.allgather(S.status.slot if S.status.restart else S.status.time_size - 1))

        # send "backward" the number of restarts in a row
        if S.status.slot >= restart_from:
            buff = np.empty(1, dtype=int)
            buff[0] = int(S.status.restarts_in_a_row + 1 if S.status.restart else 0)
            self.Send(
                comm,
                dest=S.status.slot - restart_from,
                buffer=[buff, self.MPI_INT],
                blocking=False,
            )

        # receive new number of restarts in a row
        if S.status.slot + restart_from < size:
            buff = np.empty(1, dtype=int)
            self.Recv(comm, source=(S.status.slot + restart_from), buffer=[buff, self.MPI_INT])
            S.status.restarts_in_a_row = buff[0]
        else:
            # no rank sends to this one, so the counter starts over
            S.status.restarts_in_a_row = 0

        return None