Coverage for pySDC/implementations/convergence_controller_classes/basic_restarting.py: 95%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1from pySDC.core.convergence_controller import ConvergenceController, Pars 

2from pySDC.implementations.convergence_controller_classes.spread_step_sizes import ( 

3 SpreadStepSizesBlockwise, 

4) 

5from pySDC.core.errors import ConvergenceError 

6import numpy as np 

7 

8 

class BasicRestarting(ConvergenceController):
    """
    Shared utilities for restarting steps. The flavors derived from this class

     - propagate a restart request of one step to the steps that follow it
     - cap the number of consecutive restarts of a step, after which the run either moves on or crashes

    Default control order is 95.
    """

    @classmethod
    def get_implementation(cls, useMPI):
        """
        Select the flavor of this convergence controller matching the controller type.

        Args:
            useMPI (bool): Whether or not the controller uses MPI

        Returns:
            cls: `BasicRestartingMPI` if `useMPI` is truthy, `BasicRestartingNonMPI` otherwise
        """
        return BasicRestartingMPI if useMPI else BasicRestartingNonMPI

    def __init__(self, controller, params, description, **kwargs):
        """
        Run the parent initialization and create the buffers holding the restart flags.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        super().__init__(controller, params, description)

        # both flags start out cleared; note that `kwargs` are accepted but not forwarded to the parent
        initial_flags = {"restart": False, "max_restart_reached": False}
        self.buffers = Pars(initial_flags)

    def setup(self, controller, params, description, **kwargs):
        """
        Assemble the parameters and register the hook that logs restarts with the controller.

        Default parameters are:
         - control_order (int): The order relative to other convergence controllers (95)
         - max_restarts (int): Maximum number of restarts in a row we allow each step (10)
         - crash_after_max_restarts (bool): Raise an error once the maximum is reached instead of moving on (True)
         - restart_from_first_step (bool): Restart the whole block rather than only the steps following the one
                                           that requested the restart (False)
         - step_size_spreader (pySDC.ConvergenceController): A convergence controller that takes care of
                                                             distributing the step sizes between blocks

        Values returned by the parent `setup` take precedence over these defaults.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller; must contain `useMPI`
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        spreader = SpreadStepSizesBlockwise.get_implementation(useMPI=params['useMPI'])

        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller.add_hook(LogRestarts)

        merged = {
            "control_order": 95,
            "max_restarts": 10,
            "crash_after_max_restarts": True,
            "restart_from_first_step": False,
            "step_size_spreader": spreader,
        }
        merged.update(super().setup(controller, params, description, **kwargs))
        return merged

    def setup_status_variables(self, *args, **kwargs):
        """
        Add two status variables to the steps: `restart` (initially False), telling whether to restart now, and
        `restarts_in_a_row` (initially 0), counting how many times the step has been restarted in a row.

        Returns:
            None
        """
        for name, initial_value in (('restart', False), ('restarts_in_a_row', 0)):
            self.add_status_variable_to_step(name, initial_value)

    def reset_status_variables(self, *args, **kwargs):
        """
        Reset the `restart` status variable of the steps to False. The `restarts_in_a_row` counter is not touched
        here.

        Returns:
            None
        """
        self.set_step_status_variable('restart', False)

    def dependencies(self, controller, description, **kwargs):
        """
        Load the convergence controller that spreads the step sizes between steps.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        # the spreader starts from the first restarted step exactly when we do not restart the whole block
        controller.add_convergence_controller(
            self.params.step_size_spreader,
            description=description,
            params={'spread_from_first_restarted': not self.params.restart_from_first_step},
        )
        return None

    def determine_restart(self, controller, S, **kwargs):
        """
        Placeholder for the flavor specific decision on whether a step needs to be restarted. Always raises in this
        base class.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step

        Raises:
            NotImplementedError: Always

        Returns:
            None
        """
        raise NotImplementedError("Please implement a function to determine if we need a restart here!")

132 

133 

class BasicRestartingNonMPI(BasicRestarting):
    """
    Flavor of basic restarting for controllers that do not use MPI. Information is passed between steps via
    `self.buffers` instead of actual communication.
    """

    def reset_buffers_nonMPI(self, controller, **kwargs):
        """
        Clear the buffers that stand in for communication between steps.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        for flag in ("restart", "max_restart_reached"):
            setattr(self.buffers, flag, False)

        return None

    def determine_restart(self, controller, S, MS, **kwargs):
        """
        Decide whether the current step is restarted: a restart requested by this step or recorded in the buffer is
        carried over to this step, unless the maximum number of restarts in a row has been reached. In that case
        we either raise or log a message and move on, depending on `crash_after_max_restarts`.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            MS (list): List of active steps

        Raises:
            ConvergenceError: If a restart is requested after the maximum number of restarts was reached and
                              `crash_after_max_restarts` is set

        Returns:
            None
        """
        buffers = self.buffers
        cfg = self.params

        # only the first step of the block decides whether we ran out of patience
        if S.status.first:
            buffers.max_restart_reached = S.status.restarts_in_a_row >= cfg.max_restarts

        if buffers.max_restart_reached and S.status.restart:
            count = S.status.restarts_in_a_row
            if cfg.crash_after_max_restarts:
                raise ConvergenceError(f"Restarted {count} time(s) already, surrendering now.")
            self.log(
                f"Step(s) restarted {count} time(s) already, maximum reached, moving " "on...",
                S,
            )

        # remember a restart request in the buffer and apply it to this step unless the maximum was reached
        wants_restart = S.status.restart or buffers.restart
        buffers.restart = wants_restart
        S.status.restart = wants_restart and not buffers.max_restart_reached

        # when restarting whole blocks, the last step hands the collected decision to every step
        whole_block = S.status.last and cfg.restart_from_first_step
        if whole_block and not buffers.max_restart_reached:
            for other in MS:
                other.status.restart = buffers.restart

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, MS, **kwargs):
        """
        Update the number of restarts in a row on the steps for the next block.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            MS (list): List of active steps

        Returns:
            None
        """
        if S not in MS:
            return None

        # smallest slot among the steps that restart; falls back to the last slot if no step restarts
        restarting_slots = [step.status.slot for step in MS if step.status.restart]
        restart_from = min([*restarting_slots, size - 1])

        my_slot = S.status.slot
        # NOTE(review): MS is indexed by position here, which presumably coincides with the slot - confirm
        if my_slot < restart_from:
            MS[restart_from - my_slot].status.restarts_in_a_row = 0
        else:
            target = MS[my_slot - restart_from]
            if S.status.restart:
                target.status.restarts_in_a_row = S.status.restarts_in_a_row + 1
            else:
                target.status.restarts_in_a_row = 0

        return None

216 

217 

class BasicRestartingMPI(BasicRestarting):
    """
    MPI specific version of basic restarting. Restart information travels from each step to the next one with
    point-to-point messages, or with collective operations when the whole block is restarted.
    """

    def __init__(self, controller, params, description, **kwargs):
        """
        Initialization routine. Stores the logical-or reduction operation and adds a buffer for the restart request
        of earlier steps.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller
        """
        # imported here so that the module can be loaded without mpi4py being installed
        from mpi4py import MPI

        self.OR = MPI.LOR

        super().__init__(controller, params, description)
        self.buffers = Pars({"restart": False, "max_restart_reached": False, 'restart_earlier': False})

    def determine_restart(self, controller, S, comm, **kwargs):
        """
        Restart all steps after the first one which wants to be restarted as well, but also check if we lost patience
        with the restarts and want to move on anyways.

        The limit on restarts is `restarts_in_a_row >= max_restarts`, evaluated on the first step, in both the
        step-to-step and the `restart_from_first_step` code path.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            comm (mpi4py.MPI.Intracomm): Communicator

        Raises:
            ConvergenceError: If the maximum number of restarts was reached, the first step requests another
                              restart and `crash_after_max_restarts` is set

        Returns:
            None
        """
        crash_now = False

        if S.status.first:
            # check if we performed too many restarts
            self.buffers.max_restart_reached = S.status.restarts_in_a_row >= self.params.max_restarts
            self.buffers.restart_earlier = False  # there is no earlier step

            if self.buffers.max_restart_reached and S.status.restart:
                if self.params.crash_after_max_restarts:
                    crash_now = True
                else:
                    # only announce that we move on if we actually do so rather than crash
                    self.log(
                        f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, maximum reached, moving "
                        "on...",
                        S,
                    )
        elif not S.status.prev_done and not self.params.restart_from_first_step:
            # receive information about restarts from earlier ranks
            buff = np.empty(3, dtype=bool)
            self.Recv(comm=comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
            self.buffers.restart_earlier = buff[0]
            self.buffers.max_restart_reached = buff[1]
            crash_now = buff[2]

        # decide whether to restart
        S.status.restart = (S.status.restart or self.buffers.restart_earlier) and not self.buffers.max_restart_reached

        # send information about restarts forward
        if not S.status.last and not self.params.restart_from_first_step:
            buff = np.empty(3, dtype=bool)
            buff[0] = S.status.restart
            buff[1] = self.buffers.max_restart_reached
            buff[2] = crash_now
            self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

        if self.params.restart_from_first_step:
            # use the same `>=` limit as on the first step above, such that no extra restart slips through
            max_restart_reached = comm.bcast(S.status.restarts_in_a_row >= self.params.max_restarts, root=0)
            S.status.restart = comm.allreduce(S.status.restart, op=self.OR) and not max_restart_reached
            # NOTE(review): `crash_now` is not communicated in this branch, so only the first step raises below
            # while the other ranks carry on - confirm whether the crash should be broadcast as well

        if crash_now:
            raise ConvergenceError("Surrendering because of too many restarts...")

        return None

    def prepare_next_block(self, controller, S, size, time, Tend, comm, **kwargs):
        """
        Update restarts in a row for all steps.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step
            size (int): The number of ranks
            time (list): List containing the time of all the steps
            Tend (float): Final time of the simulation
            comm (mpi4py.MPI.Intracomm): Communicator

        Returns:
            None
        """
        # smallest slot among the steps that restart; steps that do not restart contribute the last slot
        restart_from = min(comm.allgather(S.status.slot if S.status.restart else S.status.time_size - 1))

        # send "backward" the number of restarts in a row
        if S.status.slot >= restart_from:
            buff = np.empty(1, dtype=int)
            buff[0] = int(S.status.restarts_in_a_row + 1 if S.status.restart else 0)
            # non-blocking, since this rank may have to post its own receive below
            self.Send(
                comm,
                dest=S.status.slot - restart_from,
                buffer=[buff, self.MPI_INT],
                blocking=False,
            )

        # receive new number of restarts in a row
        if S.status.slot + restart_from < size:
            buff = np.empty(1, dtype=int)
            self.Recv(comm, source=(S.status.slot + restart_from), buffer=[buff, self.MPI_INT])
            S.status.restarts_in_a_row = buff[0]
        else:
            # no step sends to this one, so the count starts over
            S.status.restarts_in_a_row = 0

        return None