Coverage for pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py: 97%

214 statements  

coverage.py v7.10.4, created at 2025-08-21 06:49 +0000

import itertools

import numpy as np

from pySDC.core.controller import ParaDiagController
from pySDC.core import step as stepclass
from pySDC.core.errors import ControllerError
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting
from pySDC.helpers.ParaDiagHelper import get_G_inv_matrix


class controller_ParaDiag_nonMPI(ParaDiagController):
    """
    ParaDiag controller, running a serialized version.

    This controller uses the increment formulation. That is to say, we set up the residual of the all-at-once
    problem, put it on the right-hand side, invert the ParaDiag preconditioner on the left-hand side to compute the
    increment, and then add the increment onto the solution. For this reason, we need to replace the solution values
    in the steps with the residual values before the solves and then put the solution plus increment back into the
    steps. This is a bit counterintuitive when you access the `u` variable in the levels, but it is mathematically
    advantageous.
    """

    def __init__(self, num_procs, controller_params, description):
        """
        Initialization routine for ParaDiag controller

        Args:
            num_procs: number of parallel time steps (still serial, though), can be 1
            controller_params: parameter set for the controller and the steps
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
        """
        super().__init__(controller_params, description, useMPI=False, n_steps=num_procs)

        self.MS = []

        for l in range(num_procs):
            G_inv = get_G_inv_matrix(l, num_procs, self.params.alpha, description['sweeper_params'])
            description['sweeper_params']['G_inv'] = G_inv

            self.MS.append(stepclass.Step(description))

        self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=False)]
        for convergence_controller in self.base_convergence_controllers:
            self.add_convergence_controller(convergence_controller, description)

        if self.params.dump_setup:
            self.dump_setup(step=self.MS[0], controller_params=controller_params, description=description)

        if len(self.MS[0].levels) > 1:
            raise NotImplementedError('This controller does not support multiple levels')

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_buffers_nonMPI(self)
            C.setup_status_variables(self, MS=self.MS)

    def ParaDiag(self, local_MS_active):
        """
        Main function for ParaDiag

        For the workflow of this controller, see https://arxiv.org/abs/2103.12571

        This method changes self.MS directly by accessing active steps through local_MS_active.

        Args:
            local_MS_active (list): all active steps

        Returns:
            bool: Whether all steps are done
        """

        # if all stages are the same (or DONE), continue, otherwise abort
        stages = [S.status.stage for S in local_MS_active if S.status.stage != 'DONE']
        if stages[1:] == stages[:-1]:
            stage = stages[0]
        else:
            raise ControllerError('not all stages are equal')

        self.logger.debug(stage)

        MS_running = [S for S in local_MS_active if S.status.stage != 'DONE']

        switcher = {
            'SPREAD': self.spread,
            'IT_CHECK': self.it_check,
            'IT_PARADIAG': self.it_ParaDiag,
        }

        assert stage in switcher.keys(), f'Got unexpected stage {stage!r}'
        switcher[stage](MS_running)

        return all(S.status.done for S in local_MS_active)
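    # Rough sketch of the stage machine driven by `ParaDiag` above (stage names as used in `switcher`):
    #
    #     SPREAD --> IT_CHECK --> IT_PARADIAG --> IT_CHECK --> ... --> DONE
    #
    # `spread` initializes the steps, `it_ParaDiag` performs one iteration on all active steps at once and
    # `it_check` decides per step whether to iterate again or to finish.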

    def apply_matrix(self, mat, quantity):
        """
        Apply a matrix on the step level. Needs to be square. Puts the result back into the controller.

        Args:
            mat: square LxL matrix with L number of steps
            quantity (str): Either 'residual' or 'increment'
        """
        L = len(self.MS)
        assert np.allclose(mat.shape, L)
        assert len(mat.shape) == 2

        level = self.MS[0].levels[0]
        M = level.sweep.params.num_nodes
        prob = level.prob

        # buffer for storing the result
        res = [None] * L

        if quantity == 'residual':
            me = [S.levels[0].residual for S in self.MS]
        elif quantity == 'increment':
            me = [S.levels[0].increment for S in self.MS]
        else:
            raise NotImplementedError

        # compute matrix-vector product
        for i in range(mat.shape[0]):
            res[i] = [prob.u_init for _ in range(M)]
            for j in range(mat.shape[1]):
                for m in range(M):
                    res[i][m] += mat[i, j] * me[j][m]

        # put the result in the "output"
        for i in range(mat.shape[0]):
            for m in range(M):
                me[i][m] = res[i][m]
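    # For intuition: if the node values of the chosen quantity were gathered into one array U of shape
    # (L, M, ...), the loops above would amount to the step-wise product np.einsum('ij,j...->i...', mat, U),
    # i.e. `mat` only mixes the time-step index while collocation nodes and spatial degrees of freedom are
    # treated independently. (Illustrative comment only, assuming numpy broadcasting semantics.)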

    def compute_all_at_once_residual(self, local_MS_running):
        """
        Compute the residual of the all-at-once problem.

        This requires communicating the solutions at the end of the steps, which serve as initial conditions for
        the subsequent steps. Afterwards, the residual can be computed locally on the steps.

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            # communicate initial conditions
            S.levels[0].sweep.compute_end_point()

            for hook in self.hooks:
                hook.pre_comm(step=S, level_number=0)

            if not S.status.first:
                S.levels[0].u[0] = S.prev.levels[0].uend

            for hook in self.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)

            # compute residuals locally
            S.levels[0].sweep.compute_residual()
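    # In pySDC conventions, the local residual of step l with (communicated) initial value u_{l,0} is, roughly
    # (a sketch of what `compute_residual` evaluates, per collocation node m = 1, ..., M):
    #
    #     r_{l,m} = u_{l,0} + dt * sum_j Q[m, j] * f(u_{l,j}) - u_{l,m},
    #
    # so after passing `uend` of step l-1 into `u[0]` of step l, the step-local residuals together form the
    # right-hand side of the all-at-once increment equation.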

    def update_solution(self, local_MS_running):
        """
        Since we solve for the increment, we need to update the solution between iterations by adding the increment.

        Args:
            local_MS_running (list): list of currently running steps
        """
        for S in local_MS_running:
            for m in range(S.levels[0].sweep.coll.num_nodes):
                S.levels[0].u[m + 1] += S.levels[0].increment[m]

    def prepare_Jacobians(self, local_MS_running):
        """
        Prepare the node-wise averaged solution used to set up averaged Jacobians for non-linear problems.

        Args:
            local_MS_running (list): list of currently running steps
        """
        # get solutions for constructing average Jacobians
        if self.params.average_jacobian:
            level = local_MS_running[0].levels[0]
            M = level.sweep.coll.num_nodes

            # one fresh buffer per node so that the entries do not alias the same object
            u_avg = [level.prob.dtype_u(level.prob.init, val=0) for _ in range(M)]

            # communicate average solution
            for S in local_MS_running:
                for m in range(M):
                    u_avg[m] += S.levels[0].u[m + 1] / self.n_steps

            # store the averaged solution in the steps
            for S in local_MS_running:
                S.levels[0].u_avg = u_avg
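    # Why averaging: the preconditioner can only be diagonalized in time if every step carries the same local
    # operator. For non-linear problems the Jacobian depends on the solution, so with `average_jacobian` enabled
    # all steps linearize around the node-wise average
    #
    #     u_avg[m] = (1 / n_steps) * sum_l u_{l, m+1},
    #
    # which is exactly what `prepare_Jacobians` assembles and stores as `u_avg` on each level.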

    def it_ParaDiag(self, local_MS_running):
        """
        Do a single ParaDiag iteration. This does the following steps
         - (1) Compute the residual of the all-at-once / composite collocation problem
         - (2) Compute a weighted FFT in time to diagonalize the preconditioner
         - (3) Solve the collocation problems locally on the steps for the increment
         - (4) Compute the inverse FFT in time to go back to the original basis
         - (5) Update the solution by adding the increment

        Note that this is the only place where we compute the all-at-once residual, because doing so requires
        communication and swaps the solution values for the residuals. So after the residual tolerance is reached,
        one more ParaDiag iteration will be done.

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            for hook in self.hooks:
                hook.pre_sweep(step=S, level_number=0)

        # communicate average solution for setting up Jacobians for non-linear problems
        self.prepare_Jacobians(local_MS_running)

        # compute the all-at-once residual to use as right hand side
        self.compute_all_at_once_residual(local_MS_running)

        # weighted FFT of the residual in time
        self.FFT_in_time(quantity='residual')

        # perform local solves of "collocation problems" on the steps (can be done in parallel)
        for S in local_MS_running:
            assert len(S.levels) == 1, 'Multi-level SDC not implemented in ParaDiag'
            S.levels[0].sweep.update_nodes()

        # inverse FFT of the increment in time
        self.iFFT_in_time(quantity='increment')

        # get the next iterate by adding the increment to the previous iterate
        self.update_solution(local_MS_running)

        for S in local_MS_running:
            for hook in self.hooks:
                hook.post_sweep(step=S, level_number=0)

        # update stage
        for S in local_MS_running:
            S.status.stage = 'IT_CHECK'
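    # Sketch of steps (2) and (4) above: the alpha-circulant preconditioner P_alpha can be block-diagonalized by
    # a *weighted* Fourier transform in time, roughly
    #
    #     P_alpha = (J F (x) I) D (J F (x) I)^{-1},    J = diag(alpha**(l / L)), l = 0, ..., L - 1,
    #
    # with F the discrete Fourier matrix, (x) the Kronecker product and D block-diagonal. `FFT_in_time` and
    # `iFFT_in_time` apply these weighted transforms, which is why the solves in `update_nodes` decouple across
    # the steps and could run in parallel. (Notation roughly follows https://arxiv.org/abs/2103.12571; sketch only.)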

    def it_check(self, local_MS_running):
        """
        Key routine to check for convergence/termination

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            if S.status.iter > 0:
                for hook in self.hooks:
                    hook.post_iteration(step=S, level_number=0)

            # decide if the step is done, needs to be restarted and other things convergence related
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_iteration_processing(self, S, MS=local_MS_running)
                C.convergence_control(self, S, MS=local_MS_running)

        for S in local_MS_running:
            if not S.status.first:
                for hook in self.hooks:
                    hook.pre_comm(step=S, level_number=0)
                S.status.prev_done = S.prev.status.done  # "communicate"
                for hook in self.hooks:
                    hook.post_comm(step=S, level_number=0, add_to_stats=True)
                S.status.done = S.status.done and S.status.prev_done

            if self.params.all_to_done:
                for hook in self.hooks:
                    hook.pre_comm(step=S, level_number=0)
                S.status.done = all(T.status.done for T in local_MS_running)
                for hook in self.hooks:
                    hook.post_comm(step=S, level_number=0, add_to_stats=True)

            if not S.status.done:
                # increment iteration count here (and only here)
                S.status.iter += 1
                for hook in self.hooks:
                    hook.pre_iteration(step=S, level_number=0)
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.pre_iteration_processing(self, S, MS=local_MS_running)

                # do another ParaDiag iteration
                S.status.stage = 'IT_PARADIAG'
            else:
                S.levels[0].sweep.compute_end_point()
                for hook in self.hooks:
                    hook.post_step(step=S, level_number=0)
                S.status.stage = 'DONE'

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_buffers_nonMPI(self)

    def spread(self, local_MS_running):
        """
        Spreading phase

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:

            # first stage: spread values
            for hook in self.hooks:
                hook.pre_step(step=S, level_number=0)

            # call predictor from sweeper
            S.levels[0].sweep.predict()

            # compute the residual
            S.levels[0].sweep.compute_residual()

            # update stage
            S.status.stage = 'IT_CHECK'

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_spread_processing(self, S, MS=local_MS_running)

    def run(self, u0, t0, Tend):
        """
        Main driver for running the serial version of ParaDiag

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the last step
            stats object containing statistics for each step, each level and each iteration
        """

        # some initializations and reset of statistics
        uend = None
        num_procs = len(self.MS)
        for hook in self.hooks:
            hook.reset_stats()

        # initial ordering of the steps: 0,1,...,Np-1
        slots = list(range(num_procs))

        # initialize time variables of each step
        time = [t0 + sum(self.MS[j].dt for j in range(p)) for p in slots]

        # determine which steps are still active (time < Tend)
        active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]
        if not all(active) and any(active):
            self.logger.warning(
                'Warning: This controller will solve past your desired end time until the end of its block!'
            )
            active = [True] * len(active)

        if not any(active):
            raise ControllerError('Nothing to do, check t0, dt and Tend.')

        # compress slots according to active steps, i.e. remove all steps which have times above Tend
        active_slots = list(itertools.compress(slots, active))

        # initialize block of steps with u0
        self.restart_block(active_slots, time, u0)

        for hook in self.hooks:
            hook.post_setup(step=None, level_number=None)

        # call pre-run hook
        for S in self.MS:
            for hook in self.hooks:
                hook.pre_run(step=S, level_number=0)

        # main loop: as long as at least one step is still active (time < Tend), do something
        while any(active):
            MS_active = [self.MS[p] for p in active_slots]
            done = False
            while not done:
                done = self.ParaDiag(MS_active)

            restarts = [S.status.restart for S in MS_active]
            restart_at = np.where(restarts)[0][0] if True in restarts else len(MS_active)
            if True in restarts:  # restart part of the block
                # initial condition to next block is initial condition of step that needs restarting
                uend = self.MS[restart_at].levels[0].u[0]
                time[active_slots[0]] = time[restart_at]
                self.logger.info(f'Starting next block with initial conditions from step {restart_at}')

            else:  # move on to next block
                # initial condition for next block is the last solution of the current block
                uend = self.MS[active_slots[-1]].levels[0].uend
                time[active_slots[0]] = time[active_slots[-1]] + self.MS[active_slots[-1]].dt

            for S in MS_active[:restart_at]:
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.post_step_processing(self, S, MS=MS_active)

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                [C.prepare_next_block(self, S, len(active_slots), time, Tend, MS=MS_active) for S in self.MS]

            # set up the times of the steps for the next block
            for i in range(1, len(active_slots)):
                time[active_slots[i]] = time[active_slots[i] - 1] + self.MS[active_slots[i] - 1].dt

            # determine new set of active steps and compress slots accordingly
            active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]
            if not all(active) and any(active):
                self.logger.warning(
                    'Warning: This controller will solve past your desired end time until the end of its block!'
                )
                active = [True] * len(active)
            active_slots = list(itertools.compress(slots, active))

            # restart active steps (reset all values and pass uend to u0)
            self.restart_block(active_slots, time, uend)

        # call post-run hook
        for S in self.MS:
            for hook in self.hooks:
                hook.post_run(step=S, level_number=0)

        for S in self.MS:
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_run_processing(self, S, MS=MS_active)

        return uend, self.return_stats()
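    # Worked example of the block logic above (illustrative numbers only): with num_procs = 4 and dt = 0.1 on
    # every step, the first block covers the step times [0.0, 0.1, 0.2, 0.3]. Once it is done, uend of the last
    # step becomes u0 of the next block and the step times are shifted to [0.4, 0.5, 0.6, 0.7]. If only some of
    # the new start times lie beyond Tend, the controller warns and keeps the whole block active, i.e. it
    # integrates past Tend rather than shrinking the block.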

    def restart_block(self, active_slots, time, u0):
        """
        Helper routine to reset/restart block of (active) steps

        Args:
            active_slots: list of active steps
            time: list of new times
            u0: initial value to distribute across the steps
        """

        for j in range(len(active_slots)):
            # get slot number
            p = active_slots[j]

            # store current slot number for diagnostics
            self.MS[p].status.slot = p
            # store link to previous step
            self.MS[p].prev = self.MS[active_slots[j - 1]]

            self.MS[p].reset_step()

            # determine whether I am the first and/or last in line
            self.MS[p].status.first = active_slots.index(p) == 0
            self.MS[p].status.last = active_slots.index(p) == len(active_slots) - 1

            # initialize step with u0
            self.MS[p].init_step(u0)

            # setup G^{-1} for new number of active slots
            # self.MS[j].levels[0].sweep.set_G_inv(get_G_inv_matrix(j, len(active_slots), self.params.alpha, self.description['sweeper_params']))

            # reset some values
            self.MS[p].status.done = False
            self.MS[p].status.prev_done = False
            self.MS[p].status.iter = 0
            self.MS[p].status.stage = 'SPREAD'
            self.MS[p].status.force_done = False
            self.MS[p].status.time_size = len(active_slots)

            for l in self.MS[p].levels:
                l.tag = None
                l.status.sweep = 1

        for p in active_slots:
            for lvl in self.MS[p].levels:
                lvl.status.time = time[p]

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_status_variables(self, active_slots=active_slots)
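
# Typical usage (a hedged sketch; `my_problem_class`, `my_sweeper_class` and all parameter values below are
# placeholders for illustration, not part of this module):
#
#     description = {
#         'problem_class': my_problem_class,
#         'problem_params': {...},
#         'sweeper_class': my_sweeper_class,  # a ParaDiag-compatible sweeper
#         'sweeper_params': {'num_nodes': 3},
#         'level_params': {'dt': 0.1, 'restol': 1e-8},
#         'step_params': {'maxiter': 10},
#     }
#     controller = controller_ParaDiag_nonMPI(
#         num_procs=4, controller_params={'logger_level': 30, 'alpha': 1e-4}, description=description
#     )
#     prob = controller.MS[0].levels[0].prob
#     uend, stats = controller.run(u0=prob.u_exact(0.0), t0=0.0, Tend=0.4)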