Coverage for pySDC/implementations/controller_classes/controller_MPI.py: 66% (387 statements)
coverage.py v7.10.6, created at 2025-09-04 15:08 +0000

import numpy as np
from mpi4py import MPI

from pySDC.core.controller import Controller
from pySDC.core.errors import ControllerError
from pySDC.core.step import Step
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting


class controller_MPI(Controller):
    """
    PFASST controller, running a parallel version of PFASST in blocks (MG-style)
    """

    def __init__(self, controller_params, description, comm):
        """
        Initialization routine for the PFASST controller

        Args:
            controller_params: parameter set for the controller and the step class
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
            comm: MPI communicator
        """

        # call parent's initialization routine
        super().__init__(controller_params, description, useMPI=True)

        # create a single step per processor
        self.S: Step = Step(description)

        # pass communicator for future use
        self.comm = comm

        num_procs = self.comm.Get_size()
        rank = self.comm.Get_rank()

        # insert data on the time communicator to the steps (helpful here and there)
        self.S.status.time_size = num_procs

        self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=True)]
        for convergence_controller in self.base_convergence_controllers:
            self.add_convergence_controller(convergence_controller, description)

        if self.params.dump_setup and rank == 0:
            self.dump_setup(step=self.S, controller_params=controller_params, description=description)

        num_levels = len(self.S.levels)

        # add request handle for status send
        self.req_status = None
        # add request handle container for isend
        self.req_send = [None] * num_levels
        self.req_ibcast = None
        self.req_diff = None

        if num_procs > 1 and num_levels > 1:
            for L in self.S.levels:
                if not L.sweep.coll.right_is_node or L.sweep.params.do_coll_update:
                    raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

        if num_levels == 1 and self.params.predict_type is not None:
            self.logger.warning(
                'you have specified a predictor type but only a single level; the predictor will be ignored'
            )

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.setup_status_variables(self, comm=comm)
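
The "uend^k = u_M^k" assumption checked above holds exactly when the rightmost collocation node coincides with the right interval boundary (e.g. Gauss-Radau IIA or Gauss-Lobatto nodes) and no extra collocation update is performed. A minimal sketch of why plain Gauss-Legendre nodes fail this check, using nothing beyond NumPy:

import numpy as np

# Gauss-Legendre nodes on [-1, 1]: the right boundary is NOT a node,
# so uend cannot simply be read off as the last node value u_M.
nodes, _ = np.polynomial.legendre.leggauss(3)
print(np.isclose(nodes[-1], 1.0))  # False -> right_is_node would be False

# Gauss-Lobatto nodes include both endpoints; for three nodes they are
# [-1, 0, 1], so the right boundary IS a node and uend = u_M holds.
lobatto = np.array([-1.0, 0.0, 1.0])
print(np.isclose(lobatto[-1], 1.0))  # True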

    def run(self, u0, t0, Tend):
        """
        Main driver for running the parallel version of SDC, MSSDC, MLSDC and PFASST

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the finest level
            stats object containing statistics for each step, each level and each iteration
        """

        # reset stats to prevent double entries from old runs
        for hook in self.hooks:
            hook.reset_stats()

        # set up the initial time for each process
        all_dt = self.comm.allgather(self.S.dt)
        time = t0 + sum(all_dt[: self.comm.rank])

        active = time < Tend - 10 * np.finfo(float).eps
        comm_active = self.comm.Split(active)
        self.S.status.slot = comm_active.rank

        if self.comm.rank == 0 and not active:
            raise ControllerError('Nothing to do, check t0, dt and Tend!')

        # initialize block of steps with u0
        self.restart_block(comm_active.size, time, u0, comm=comm_active)
        uend = u0

        # call post-setup hook
        for hook in self.hooks:
            hook.post_setup(step=None, level_number=None)

        # call pre-run hook
        for hook in self.hooks:
            hook.pre_run(step=self.S, level_number=0)

        comm_active.Barrier()

        # while any process is still active...
        while active:
            while not self.S.status.done:
                self.pfasst(comm_active, comm_active.size)

            # determine where to restart
            restarts = comm_active.allgather(self.S.status.restart)

            # communicate time and solution to be used as the next initial conditions
            if True in restarts:
                restart_at = np.where(restarts)[0][0]
                uend = self.S.levels[0].u[0].bcast(root=restart_at, comm=comm_active)
                tend = comm_active.bcast(self.S.time, root=restart_at)
                self.logger.info(f'Starting next block with initial conditions from step {restart_at}')
            else:
                uend = self.S.levels[0].uend.bcast(root=comm_active.size - 1, comm=comm_active)
                tend = comm_active.bcast(self.S.time + self.S.dt, root=comm_active.size - 1)

            # do convergence controller bookkeeping
            if not self.S.status.restart:
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.post_step_processing(self, self.S, comm=comm_active)

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.prepare_next_block(self, self.S, self.S.status.time_size, tend, Tend, comm=comm_active)

            # set the new time
            all_dt = comm_active.allgather(self.S.dt)
            time = tend + sum(all_dt[: self.S.status.slot])

            active = time < Tend - 10 * np.finfo(float).eps

            # check if we need to split the communicator (some processes drop out near Tend)
            if tend + sum(all_dt[: comm_active.size - 1]) >= Tend - 10 * np.finfo(float).eps:
                comm_active_new = comm_active.Split(active)
                comm_active.Free()
                comm_active = comm_active_new

                self.S.status.slot = comm_active.rank

            # initialize the next block of steps with the previous end values
            if active:
                self.restart_block(comm_active.size, time, uend, comm=comm_active)

        # call post-run hook
        for hook in self.hooks:
            hook.post_run(step=self.S, level_number=0)

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_run_processing(self, self.S, comm=self.comm)

        comm_active.Free()

        return uend, self.return_stats()
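
Each active process owns exactly one time step of the current block: its local start time is t0 plus the step sizes of all earlier ranks, and it stays active only while that start time lies before Tend. A small plain-Python sketch of this decomposition (no MPI; all_dt stands in for the allgather result):

import numpy as np

t0, Tend = 0.0, 0.35
all_dt = [0.1] * 4  # what comm.allgather(self.S.dt) returns for uniform dt

for rank in range(4):
    time = t0 + sum(all_dt[:rank])
    active = time < Tend - 10 * np.finfo(float).eps
    print(rank, round(time, 10), active)
# ranks 0..3 start at t = 0.0, 0.1, 0.2, 0.3 and are all active;
# with Tend = 0.3 instead, rank 3 would land in the inactive communicator.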

    def restart_block(self, size, time, u0, comm):
        """
        Helper routine to reset/restart a block of (active) steps

        Args:
            size: number of active time steps
            time: current time
            u0: initial value to distribute across the steps
            comm: the communicator
        """

        # store links to the previous and next steps
        self.S.prev = (self.S.status.slot - 1) % size
        self.S.next = (self.S.status.slot + 1) % size

        # reset the step
        self.S.reset_step()
        # determine whether I am the first and/or last in line
        self.S.status.first = self.S.prev == size - 1
        self.S.status.last = self.S.next == 0
        # initialize the step with u0
        self.S.init_step(u0)
        # reset some values
        self.S.status.done = False
        self.S.status.iter = 0
        self.S.status.stage = 'SPREAD'
        for lvl in self.S.levels:
            lvl.tag = None
        self.req_status = None
        self.req_ibcast = None
        self.req_diff = None
        self.req_send = [None] * len(self.S.levels)
        self.S.status.prev_done = False
        self.S.status.force_done = False

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_status_variables(self, comm=comm)

        self.S.status.time_size = size

        for lvl in self.S.levels:
            lvl.status.time = time
            lvl.status.sweep = 1
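
The previous/next indices wrap around modulo the block size, so the active steps form a ring, and the wrap-around is what flags the first and last step. A short sketch of that arithmetic for a block of four steps:

size = 4
for slot in range(size):
    prev = (slot - 1) % size
    nxt = (slot + 1) % size
    first = prev == size - 1  # only slot 0 wraps backwards
    last = nxt == 0           # only slot size-1 wraps forwards
    print(slot, prev, nxt, first, last)
# slot 0: prev=3, first=True; slot 3: nxt=0, last=True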

    def recv(self, target, source, tag=None, comm=None):
        """
        Receive function

        Args:
            target: level which will receive the values
            source: rank of the process which initiated the send
            tag: identifier to check if this message is really for me
            comm: communicator
        """
        req = target.u[0].irecv(source=source, tag=tag, comm=comm)
        self.wait_with_interrupt(request=req)
        if self.S.status.force_done:
            return None
        # re-evaluate f on the left interval boundary
        target.f[0] = target.prob.eval_f(target.u[0], target.time)

    def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False):
        """
        Function to perform the send, including bookkeeping and logging

        Args:
            comm: the communicator
            blocking: flag to indicate that we need blocking communication
            level: the level number
            add_to_stats: a flag to add the communication timing to the stats in the hooks (defaults to False)
        """
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)

        if not blocking:
            self.wait_with_interrupt(request=self.req_send[level])
            if self.S.status.force_done:
                return None

        self.S.levels[level].sweep.compute_end_point()

        if not self.S.status.last:
            self.logger.debug(
                'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.next,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
            self.req_send[level] = self.S.levels[level].uend.isend(
                dest=self.S.next, tag=level * 100 + self.S.status.iter, comm=comm
            )
            if blocking:
                self.wait_with_interrupt(request=self.req_send[level])
                if self.S.status.force_done:
                    return None

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)

    def recv_full(self, comm, level=None, add_to_stats=False):
        """
        Function to perform the recv, including bookkeeping and logging

        Args:
            comm: the communicator
            level: the level number
            add_to_stats: a flag to add the communication timing to the stats in the hooks (defaults to False)
        """

        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)
        if not self.S.status.first and not self.S.status.prev_done:
            self.logger.debug(
                'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.prev,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
            self.recv(target=self.S.levels[level], source=self.S.prev, tag=level * 100 + self.S.status.iter, comm=comm)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
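
Messages between neighboring steps are matched by the tag level * 100 + iter, which is unique per (level, iteration) pair as long as fewer than 100 iterations are performed; the diff messages of the iteration estimator use the separate tag 999. A quick sketch of the scheme and its implicit limit:

def make_tag(level, iteration):
    # unique per (level, iteration) while iteration < 100
    return level * 100 + iteration

print(make_tag(0, 3))    # 3   -> finest level, iteration 3
print(make_tag(1, 3))    # 103 -> next coarser level, iteration 3
print(make_tag(0, 103))  # 103 -> would collide with (1, 3) once iter >= 100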

    def wait_with_interrupt(self, request):
        """
        Wrapper for waiting for the completion of a non-blocking communication; can be interrupted

        Args:
            request: request to wait for
        """
        if request is not None and self.req_ibcast is not None:
            while not request.Test():
                if self.req_ibcast.Test():
                    self.logger.debug(f'{self.S.status.slot} has been cancelled during {self.S.status.stage}..')
                    self.S.status.stage = f'CANCELLED_{self.S.status.stage}'
                    self.S.status.force_done = True
                    return None
        if request is not None:
            request.Wait()
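
This polling loop is what lets the iteration estimator's broadcast cancel an otherwise blocking wait. A self-contained mpi4py sketch of the same pattern (run with at least two ranks; the never-matched Irecv stands in for a pending request and is an assumption of this sketch, not part of the controller):

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# non-blocking broadcast playing the role of req_ibcast
stop_buf = np.array([1 if rank == 0 else 0], dtype=np.intc)
stop_req = comm.Ibcast(stop_buf, root=0)

if rank == 0:
    stop_req.Wait()  # root only signals "done" to everyone else
else:
    # a receive that will never be matched, i.e. a request that never completes
    req = comm.Irecv(np.empty(1, dtype=np.intc), source=0, tag=42)
    while not req.Test():    # poll the pending request ...
        if stop_req.Test():  # ... while watching the interrupt channel
            req.Cancel()     # clean up, as pfasst() does on interrupt
            break
    print(f'rank {rank} was interrupted')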

    def check_iteration_estimate(self, comm):
        """
        Routine to compute and check the error/iteration estimate

        Args:
            comm: time-communicator
        """

        # compute the difference between old and new values
        diff_new = 0.0
        L = self.S.levels[0]

        for m in range(1, L.sweep.coll.num_nodes + 1):
            diff_new = max(diff_new, abs(L.uold[m] - L.u[m]))

        # send the diff forward
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=0)

        self.wait_with_interrupt(request=self.req_diff)
        if self.S.status.force_done:
            return None

        if not self.S.status.first:
            prev_diff = np.empty(1, dtype=float)
            req = comm.Irecv((prev_diff, MPI.DOUBLE), source=self.S.prev, tag=999)
            self.wait_with_interrupt(request=req)
            if self.S.status.force_done:
                return None
            self.logger.debug(
                'recv diff: status %s, process %s, time %s, source %s, tag %s, iter %s'
                % (prev_diff, self.S.status.slot, self.S.time, self.S.prev, 999, self.S.status.iter)
            )
            diff_new = max(prev_diff[0], diff_new)

        if not self.S.status.last:
            self.logger.debug(
                'isend diff: status %s, process %s, time %s, target %s, tag %s, iter %s'
                % (diff_new, self.S.status.slot, self.S.time, self.S.next, 999, self.S.status.iter)
            )
            tmp = np.array(diff_new, dtype=float)
            self.req_diff = comm.Issend((tmp, MPI.DOUBLE), dest=self.S.next, tag=999)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=0)

        # store values from the first iteration
        if self.S.status.iter == 1:
            self.S.status.diff_old_loc = diff_new
            self.S.status.diff_first_loc = diff_new
        # compute the iteration estimate
        elif self.S.status.iter > 1:
            Ltilde_loc = min(diff_new / self.S.status.diff_old_loc, 0.9)
            self.S.status.diff_old_loc = diff_new
            alpha = 1 / (1 - Ltilde_loc) * self.S.status.diff_first_loc
            Kest_loc = np.log(self.S.params.errtol / alpha) / np.log(Ltilde_loc) * 1.05  # safety factor!
            self.logger.debug(
                f'LOCAL: {L.time:8.4f}, {self.S.status.iter}: {int(np.ceil(Kest_loc))}, '
                f'{Ltilde_loc:8.6e}, {Kest_loc:8.6e}, '
                f'{Ltilde_loc ** self.S.status.iter * alpha:8.6e}'
            )
            Kest_glob = Kest_loc
            # if the estimated iteration count is reached, send the interrupt
            if np.ceil(Kest_glob) <= self.S.status.iter:
                if self.S.status.last:
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting..')
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
                    # note: the buffer dtype must match the np.intc buffer posted in pfasst()
                    comm.Ibcast((np.array([1], dtype=np.intc), MPI.INT), root=self.S.status.slot).Wait()
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting done')
                    self.S.status.done = True
                else:
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
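
The estimator models the iteration as a linear contraction: with a contraction factor Ltilde fitted from consecutive increments, the error after k iterations behaves like alpha * Ltilde^k, and solving alpha * Ltilde^K = errtol for K yields the predicted iteration count (padded by the 5% safety factor). A small numeric sketch with a made-up sequence of increments:

import numpy as np

errtol = 1e-8
diffs = [1e-1, 2e-2, 4e-3]  # hypothetical max-increments of iterations 1..3

Ltilde = min(diffs[-1] / diffs[-2], 0.9)  # estimated contraction factor: 0.2
alpha = diffs[0] / (1 - Ltilde)           # geometric-series bound: 0.125
Kest = np.log(errtol / alpha) / np.log(Ltilde) * 1.05  # safety factor

print(Ltilde, alpha, int(np.ceil(Kest)))  # 0.2 0.125 11 -> stop after ~11 iterations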

    def pfasst(self, comm, num_procs):
        """
        Main function including the stages of SDC, MLSDC and PFASST (the "controller")

        For the workflow of this controller, check out one of our PFASST talks or the pySDC paper

        Args:
            comm: communicator
            num_procs (int): number of parallel processes
        """

        stage = self.S.status.stage

        self.logger.debug(stage + ' - process ' + str(self.S.status.slot))

        # wait for the interrupt, if the iteration estimator is used
        if self.params.use_iteration_estimator and stage == 'SPREAD' and not self.S.status.last:
            done = np.empty(1, dtype=np.intc)  # dtype must match the buffer broadcast by the last process
            self.req_ibcast = comm.Ibcast((done, MPI.INT), root=comm.Get_size() - 1)

        # if the interrupt has arrived, clean up and finish
        if self.params.use_iteration_estimator and not self.S.status.last and self.req_ibcast.Test():
            self.logger.debug(f'{self.S.status.slot} is done..')
            self.S.status.done = True

            if stage != 'IT_CHECK':
                self.logger.debug(f'Rewinding {self.S.status.slot} after {stage}..')
                self.S.levels[0].u[1:] = self.S.levels[0].uold[1:]

                for hook in self.hooks:
                    hook.post_iteration(step=self.S, level_number=0)

            for req in self.req_send:
                if req is not None and req != MPI.REQUEST_NULL:
                    req.Cancel()
            if self.req_status is not None and self.req_status != MPI.REQUEST_NULL:
                self.req_status.Cancel()
            if self.req_diff is not None and self.req_diff != MPI.REQUEST_NULL:
                self.req_diff.Cancel()

            self.S.status.stage = 'DONE'
            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)

        else:
            # start cycling, if not interrupted
            switcher = {
                'SPREAD': self.spread,
                'PREDICT': self.predict,
                'IT_CHECK': self.it_check,
                'IT_FINE': self.it_fine,
                'IT_DOWN': self.it_down,
                'IT_COARSE': self.it_coarse,
                'IT_UP': self.it_up,
            }

            switcher.get(stage, self.default)(comm, num_procs)
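
Each call to pfasst() executes exactly one stage and sets the next one, so the loop in run() drives a small state machine. A toy walk through the transitions encoded by the stage methods below, for the multi-level case without restarts (single-level SDC skips PREDICT and the DOWN/UP legs, and IT_CHECK moves to DONE once converged):

transitions = {
    'SPREAD': 'PREDICT',    # multi-level: the predictor follows the spread
    'PREDICT': 'IT_CHECK',
    'IT_CHECK': 'IT_DOWN',  # while not converged; 'DONE' once converged
    'IT_DOWN': 'IT_COARSE',
    'IT_COARSE': 'IT_UP',
    'IT_UP': 'IT_FINE',
    'IT_FINE': 'IT_CHECK',
}

stage, seen = 'SPREAD', []
for _ in range(8):
    seen.append(stage)
    stage = transitions[stage]
print(' -> '.join(seen))
# SPREAD -> PREDICT -> IT_CHECK -> IT_DOWN -> IT_COARSE -> IT_UP -> IT_FINE -> IT_CHECK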

    def spread(self, comm, num_procs):
        """
        Spreading phase
        """

        # first stage: spread values
        for hook in self.hooks:
            hook.pre_step(step=self.S, level_number=0)

        # call predictor from sweeper
        self.S.levels[0].sweep.predict()

        if self.params.use_iteration_estimator:
            # store the previous iterate to compute the difference later on
            self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST with predictor
            self.S.status.stage = 'PREDICT'
        else:
            self.S.status.stage = 'IT_CHECK'

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_spread_processing(self, self.S, comm=comm)

    def predict(self, comm, num_procs):
        """
        Predictor phase
        """

        for hook in self.hooks:
            hook.pre_predict(step=self.S, level_number=0)

        if self.params.predict_type is None:
            pass

        elif self.params.predict_type == 'fine_only':
            # do a fine sweep only
            self.S.levels[0].sweep.update_nodes()

        # elif self.params.predict_type == 'libpfasst_style':
        #
        #     # restrict to coarsest level
        #     for l in range(1, len(self.S.levels)):
        #         self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.first:
        #         self.logger.debug('recv data predict: process %s, stage %s, time, %s, source %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
        #                            self.S.status.iter))
        #         self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #
        #     # do the sweep with new values
        #     self.S.levels[-1].sweep.update_nodes()
        #     self.S.levels[-1].sweep.compute_end_point()
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.last:
        #         self.logger.debug('send data predict: process %s, stage %s, time, %s, target %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
        #                            self.S.status.iter))
        #         self.S.levels[-1].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm).Wait()
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=True)
        #
        #     # go back to fine level, sweeping
        #     for l in range(len(self.S.levels) - 1, 0, -1):
        #         # prolong values
        #         self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])
        #         # on middle levels: do sweep as usual
        #         if l - 1 > 0:
        #             self.S.levels[l - 1].sweep.update_nodes()
        #
        #     # end with a fine sweep
        #     self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'pfasst_burnin':
            # restrict to the coarsest level
            for l in range(1, len(self.S.levels)):
                self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])

            for p in range(self.S.status.slot + 1):
                if p != 0:
                    self.recv_full(comm=comm, level=len(self.S.levels) - 1)
                    if self.S.status.force_done:
                        return None

                # do the sweep with new values
                self.S.levels[-1].sweep.update_nodes()
                self.S.levels[-1].sweep.compute_end_point()

                self.send_full(
                    comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=(p == self.S.status.slot)
                )
                if self.S.status.force_done:
                    return None

            # interpolate back to the finest level
            for l in range(len(self.S.levels) - 1, 0, -1):
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            self.recv_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # end this with a fine sweep
            self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'fmg':
            # TODO: implement FMG predictor
            raise NotImplementedError('FMG predictor is not yet implemented')

        else:
            raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)

        for hook in self.hooks:
            hook.post_predict(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'
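
In the pfasst_burnin predictor, step number slot performs slot + 1 coarse sweeps and receives from its predecessor before every sweep except the first, so each step enters the iteration with a coarse initial guess that is consistent with all earlier steps. A toy count of the resulting work per slot:

for slot in range(4):
    sweeps = slot + 1  # iterations of the burn-in loop above
    recvs = sum(1 for p in range(sweeps) if p != 0)
    print(f'slot {slot}: {sweeps} coarse sweep(s), {recvs} receive(s)')
# slot 0: 1 coarse sweep(s), 0 receive(s) ... slot 3: 4 coarse sweep(s), 3 receive(s)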

    def it_check(self, comm, num_procs):
        """
        Key routine to check for convergence/termination
        """

        # update values to compute the residual
        self.send_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        self.recv_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        # compute the residual
        self.S.levels[0].sweep.compute_residual(stage='IT_CHECK')

        if self.params.use_iteration_estimator:
            # TODO: replace with convergence controller
            self.check_iteration_estimate(comm=comm)

        if self.S.status.force_done:
            return None

        if self.S.status.iter > 0:
            for hook in self.hooks:
                hook.post_iteration(step=self.S, level_number=0)

        # decide if the step is done, needs to be restarted, and other convergence-related things
        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_iteration_processing(self, self.S, comm=comm)
            C.convergence_control(self, self.S, comm=comm)

        # if not done, start the next iteration
        if not self.S.status.done:
            # increment iteration count here (and only here)
            self.S.status.iter += 1

            for hook in self.hooks:
                hook.pre_iteration(step=self.S, level_number=0)
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.pre_iteration_processing(self, self.S, comm=comm)

            if self.params.use_iteration_estimator:
                # store the previous iterate to compute the difference later on
                self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

            if len(self.S.levels) > 1:  # MLSDC or PFASST
                self.S.status.stage = 'IT_DOWN'
            else:
                if num_procs == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                    self.S.status.stage = 'IT_FINE'
                else:
                    self.S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-like)

        else:
            if not self.params.use_iteration_estimator:
                # Need to finish all pending isend requests. These will occur for the first active process, since
                # in the last iteration the wait statement will not be called ("send and forget")
                for req in self.req_send:
                    if req is not None:
                        req.Wait()
                if self.req_status is not None:
                    self.req_status.Wait()
                if self.req_diff is not None:
                    self.req_diff.Wait()
            else:
                for req in self.req_send:
                    if req is not None:
                        req.Cancel()
                if self.req_status is not None:
                    self.req_status.Cancel()
                if self.req_diff is not None:
                    self.req_diff.Cancel()

            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)
            self.S.status.stage = 'DONE'
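
The "send and forget" pattern above means the last isend of a converged step is never waited on inside the sweep loops, so it must be drained here before the block is torn down; only the estimator path cancels instead, since its receivers may already have stopped. A minimal mpi4py sketch of firing several isends and draining them later (run with two ranks; the name req_send mirrors the controller but the setup is hypothetical):

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    # "send and forget": one isend per level, keeping only the handles
    req_send = [comm.isend(float(level), dest=1, tag=level * 100) for level in range(3)]
    # ... useful work would happen here ...
    for req in req_send:  # drain before the buffers go out of scope
        if req is not None:
            req.wait()
elif rank == 1:
    for level in range(3):
        print(comm.recv(source=0, tag=level * 100))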

    def it_fine(self, comm, num_procs):
        """
        Fine sweeps
        """

        nsweeps = self.S.levels[0].params.nsweeps

        self.S.levels[0].status.sweep = 0

        # do the fine sweeps
        for k in range(nsweeps):
            self.S.levels[0].status.sweep += 1

            # send values forward
            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # recv values from the previous step
            self.recv_full(comm=comm, level=0, add_to_stats=(k == nsweeps - 1))
            if self.S.status.force_done:
                return None

            for hook in self.hooks:
                hook.pre_sweep(step=self.S, level_number=0)

            self.S.levels[0].sweep.updateVariableCoeffs(k + 1)  # update QDelta coefficients if the preconditioner is variable
            self.S.levels[0].sweep.update_nodes()
            self.S.levels[0].sweep.compute_residual(stage='IT_FINE')

            for hook in self.hooks:
                hook.post_sweep(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'

    def it_down(self, comm, num_procs):
        """
        Go down the hierarchy from the finest to the coarsest level
        """

        self.S.transfer(source=self.S.levels[0], target=self.S.levels[1])

        # sweep and send on the middle levels (neither the finest nor the coarsest)
        for l in range(1, len(self.S.levels) - 1):
            nsweeps = self.S.levels[l].params.nsweeps

            for _ in range(nsweeps):
                self.send_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                self.recv_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                for hook in self.hooks:
                    hook.pre_sweep(step=self.S, level_number=l)

                self.S.levels[l].sweep.update_nodes()
                self.S.levels[l].sweep.compute_residual(stage='IT_DOWN')
                for hook in self.hooks:
                    hook.post_sweep(step=self.S, level_number=l)

            # transfer further down the hierarchy
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])

        # update stage
        self.S.status.stage = 'IT_COARSE'

    def it_coarse(self, comm, num_procs):
        """
        Coarse sweep
        """

        # receive from the previous step (if not first)
        self.recv_full(comm=comm, level=len(self.S.levels) - 1)
        if self.S.status.force_done:
            return None

        # do the sweep
        for hook in self.hooks:
            hook.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        assert self.S.levels[-1].params.nsweeps == 1, (
            'ERROR: this controller can only work with one sweep on the coarse level, got %s'
            % self.S.levels[-1].params.nsweeps
        )
        self.S.levels[-1].sweep.update_nodes()
        self.S.levels[-1].sweep.compute_residual(stage='IT_COARSE')
        for hook in self.hooks:
            hook.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        self.S.levels[-1].sweep.compute_end_point()

        # send to the next step
        self.send_full(comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=True)
        if self.S.status.force_done:
            return None

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST
            self.S.status.stage = 'IT_UP'
        else:
            self.S.status.stage = 'IT_CHECK'  # MSSDC

    def it_up(self, comm, num_procs):
        """
        Prolong corrections up to the finest level (parallel)
        """

        # receive and sweep on the middle levels (except for the coarsest one)
        for l in range(len(self.S.levels) - 1, 0, -1):
            # prolong values
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            # on middle levels: do the sweep as usual
            if l - 1 > 0:
                nsweeps = self.S.levels[l - 1].params.nsweeps

                for k in range(nsweeps):
                    self.send_full(comm=comm, level=l - 1)
                    if self.S.status.force_done:
                        return None

                    self.recv_full(comm=comm, level=l - 1, add_to_stats=(k == nsweeps - 1))
                    if self.S.status.force_done:
                        return None

                    for hook in self.hooks:
                        hook.pre_sweep(step=self.S, level_number=l - 1)
                    self.S.levels[l - 1].sweep.update_nodes()
                    self.S.levels[l - 1].sweep.compute_residual(stage='IT_UP')
                    for hook in self.hooks:
                        hook.post_sweep(step=self.S, level_number=l - 1)

        # update stage
        self.S.status.stage = 'IT_FINE'

    def default(self, comm, num_procs):
        """
        Default routine to catch an unexpected stage

        Note: the stage dispatcher in pfasst() calls every handler as f(comm, num_procs),
        so this fallback must accept both arguments.
        """
        raise ControllerError('Weird stage, got %s' % self.S.status.stage)
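
For context, this is roughly how the controller is driven from user code, launched with e.g. mpiexec -n 4 python my_run.py. The snippet is schematic: MyProblem and MySweeper are hypothetical stand-ins for concrete pySDC problem and sweeper classes, and a real multi-level setup also needs transfer classes in the description:

from mpi4py import MPI
from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

description = {
    'problem_class': MyProblem,  # placeholder: a concrete pySDC problem class
    'problem_params': {},
    'sweeper_class': MySweeper,  # placeholder: a concrete pySDC sweeper class
    'sweeper_params': {'num_nodes': 3},
    'level_params': {'dt': 0.1},
    'step_params': {'maxiter': 10},
}
controller_params = {'logger_level': 30}

controller = controller_MPI(controller_params=controller_params, description=description, comm=MPI.COMM_WORLD)
uinit = controller.S.levels[0].prob.u_exact(0.0)  # initial condition from the problem
uend, stats = controller.run(u0=uinit, t0=0.0, Tend=0.4)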