Coverage for pySDC/implementations/controller_classes/controller_MPI.py: 66% (386 statements)


import numpy as np
from mpi4py import MPI

from pySDC.core.controller import Controller
from pySDC.core.errors import ControllerError
from pySDC.core.step import Step
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting


class controller_MPI(Controller):
    """
    PFASST controller, running a parallel version of PFASST in blocks (MG-style)
    """

    def __init__(self, controller_params, description, comm):
        """
        Initialization routine for PFASST controller

        Args:
            controller_params: parameter set for the controller and the step class
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
            comm: MPI communicator
        """

        # call parent's initialization routine
        super().__init__(controller_params, description, useMPI=True)

        # create single step per processor
        self.S = Step(description)

        # pass communicator for future use
        self.comm = comm

        num_procs = self.comm.Get_size()
        rank = self.comm.Get_rank()

        # insert data on time communicator to the steps (helpful here and there)
        self.S.status.time_size = num_procs

        self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=True)]
        for convergence_controller in self.base_convergence_controllers:
            self.add_convergence_controller(convergence_controller, description)

        if self.params.dump_setup and rank == 0:
            self.dump_setup(step=self.S, controller_params=controller_params, description=description)

        num_levels = len(self.S.levels)

        # add request handler for status send
        self.req_status = None
        # add request handle container for isend
        self.req_send = [None] * num_levels
        self.req_ibcast = None
        self.req_diff = None

        if num_procs > 1 and num_levels > 1:
            for L in self.S.levels:
                if not L.sweep.coll.right_is_node or L.sweep.params.do_coll_update:
                    raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

        if num_levels == 1 and self.params.predict_type is not None:
            self.logger.warning(
                'you have specified a predictor type but only a single level; the predictor will be ignored'
            )

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.setup_status_variables(self, comm=comm)
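
    # A minimal construction sketch (an assumption based on the pySDC tutorials,
    # not part of this file; `MyProblem` and `MySweeper` are hypothetical
    # placeholders for concrete problem/sweeper classes):
    #
    #     description = {
    #         'problem_class': MyProblem,            # hypothetical
    #         'problem_params': {},                  # problem-specific parameters
    #         'sweeper_class': MySweeper,            # hypothetical
    #         'sweeper_params': {'num_nodes': 3},
    #         'level_params': {'dt': 0.1},
    #         'step_params': {'maxiter': 20},
    #     }
    #     controller = controller_MPI(controller_params={'logger_level': 30},
    #                                 description=description, comm=MPI.COMM_WORLD)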

    def run(self, u0, t0, Tend):
        """
        Main driver for running the parallel version of SDC, MSSDC, MLSDC and PFASST

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the finest level
            stats object containing statistics for each step, each level and each iteration
        """

        # reset stats to prevent double entries from old runs
        for hook in self.hooks:
            hook.reset_stats()

        # set up the time initially
        all_dt = self.comm.allgather(self.S.dt)
        time = t0 + sum(all_dt[: self.comm.rank])

        active = time < Tend - 10 * np.finfo(float).eps
        comm_active = self.comm.Split(active)
        self.S.status.slot = comm_active.rank

        if self.comm.rank == 0 and not active:
            raise ControllerError('Nothing to do, check t0, dt and Tend!')

        # initialize block of steps with u0
        self.restart_block(comm_active.size, time, u0, comm=comm_active)
        uend = u0

        # call post-setup hook
        for hook in self.hooks:
            hook.post_setup(step=None, level_number=None)

        # call pre-run hook
        for hook in self.hooks:
            hook.pre_run(step=self.S, level_number=0)

        comm_active.Barrier()

        # while any process is still active...
        while active:
            while not self.S.status.done:
                self.pfasst(comm_active, comm_active.size)

            # determine where to restart
            restarts = comm_active.allgather(self.S.status.restart)

            # communicate time and solution to be used as the next initial conditions
            if any(restarts):
                restart_at = np.where(restarts)[0][0]
                uend = self.S.levels[0].u[0].bcast(root=restart_at, comm=comm_active)
                tend = comm_active.bcast(self.S.time, root=restart_at)
                self.logger.info(f'Starting next block with initial conditions from step {restart_at}')
            else:
                uend = self.S.levels[0].uend.bcast(root=comm_active.size - 1, comm=comm_active)
                tend = comm_active.bcast(self.S.time + self.S.dt, root=comm_active.size - 1)

            # do convergence controller stuff
            if not self.S.status.restart:
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.post_step_processing(self, self.S, comm=comm_active)

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.prepare_next_block(self, self.S, self.S.status.time_size, tend, Tend, comm=comm_active)

            # set new time
            all_dt = comm_active.allgather(self.S.dt)
            time = tend + sum(all_dt[: self.S.status.slot])

            active = time < Tend - 10 * np.finfo(float).eps

            # check if we need to split the communicator
            if tend + sum(all_dt[: comm_active.size - 1]) >= Tend - 10 * np.finfo(float).eps:
                comm_active_new = comm_active.Split(active)
                comm_active.Free()
                comm_active = comm_active_new

            self.S.status.slot = comm_active.rank

            # initialize the next block of steps with the new initial conditions
            if active:
                self.restart_block(comm_active.size, time, uend, comm=comm_active)

        # call post-run hook
        for hook in self.hooks:
            hook.post_run(step=self.S, level_number=0)

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_run_processing(self, self.S, comm=self.comm)

        comm_active.Free()

        return uend, self.return_stats()
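
    # Block decomposition sketch (illustrative numbers, not from this file): with
    # 4 active ranks and a uniform dt = 0.1 starting at t0 = 0.0, rank p computes
    # the step starting at t0 + p * dt, so one block covers [0.0, 0.4); the next
    # block restarts from tend = 0.4 with uend broadcast from the last active rank.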

    def restart_block(self, size, time, u0, comm):
        """
        Helper routine to reset/restart a block of (active) steps

        Args:
            size: number of active time steps
            time: current time
            u0: initial value to distribute across the steps
            comm: the communicator

        Returns:
            block of (all) steps
        """

        # store links to the previous and next steps
        self.S.prev = (self.S.status.slot - 1) % size
        self.S.next = (self.S.status.slot + 1) % size

        # reset step
        self.S.reset_step()
        # determine whether I am the first and/or last in line
        self.S.status.first = self.S.prev == size - 1
        self.S.status.last = self.S.next == 0
        # initialize step with u0
        self.S.init_step(u0)
        # reset some values
        self.S.status.done = False
        self.S.status.iter = 0
        self.S.status.stage = 'SPREAD'
        for l in self.S.levels:
            l.tag = None
        self.req_status = None
        self.req_ibcast = None
        self.req_diff = None
        self.req_send = [None] * len(self.S.levels)
        self.S.status.prev_done = False
        self.S.status.force_done = False

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_status_variables(self, comm=comm)

        self.S.status.time_size = size

        for lvl in self.S.levels:
            lvl.status.time = time
            lvl.status.sweep = 1
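
    # Neighbor bookkeeping example (values implied by the modular arithmetic
    # above): with size = 4, slot 0 gets prev = 3 and next = 1, so status.first is
    # True; slot 3 gets next = 0, so status.last is True.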

    def recv(self, target, source, tag=None, comm=None):
        """
        Receive function

        Args:
            target: level which will receive the values
            source: level which initiated the send
            tag: identifier to check if this message is really for me
            comm: communicator
        """
        req = target.u[0].irecv(source=source, tag=tag, comm=comm)
        self.wait_with_interrupt(request=req)
        if self.S.status.force_done:
            return None
        # re-evaluate f on left interval boundary
        target.f[0] = target.prob.eval_f(target.u[0], target.time)

    def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False):
        """
        Function to perform the send, including bookkeeping and logging

        Args:
            comm: the communicator
            blocking: flag to indicate that we need blocking communication
            level: the level number
            add_to_stats: a flag to end recording data in the hooks (defaults to False)
        """
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)

        if not blocking:
            self.wait_with_interrupt(request=self.req_send[level])
            if self.S.status.force_done:
                return None

        self.S.levels[level].sweep.compute_end_point()

        if not self.S.status.last:
            self.logger.debug(
                'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.next,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
            self.req_send[level] = self.S.levels[level].uend.isend(
                dest=self.S.next, tag=level * 100 + self.S.status.iter, comm=comm
            )
            if blocking:
                self.wait_with_interrupt(request=self.req_send[level])
                if self.S.status.force_done:
                    return None

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
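
    # Tagging convention used by send_full/recv_full: tag = level * 100 +
    # iteration, e.g. level 1 at iteration 3 gives tag 103, so messages from
    # different levels and iterations stay distinguishable (implicitly assuming
    # fewer than 100 iterations).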

    def recv_full(self, comm, level=None, add_to_stats=False):
        """
        Function to perform the recv, including bookkeeping and logging

        Args:
            comm: the communicator
            level: the level number
            add_to_stats: a flag to end recording data in the hooks (defaults to False)
        """

        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)
        if not self.S.status.first and not self.S.status.prev_done:
            self.logger.debug(
                'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.prev,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
            self.recv(target=self.S.levels[level], source=self.S.prev, tag=level * 100 + self.S.status.iter, comm=comm)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)

    def wait_with_interrupt(self, request):
        """
        Wrapper for waiting for the completion of a non-blocking communication, can be interrupted

        Args:
            request: request to wait for
        """
        if request is not None and self.req_ibcast is not None:
            while not request.Test():
                if self.req_ibcast.Test():
                    self.logger.debug(f'{self.S.status.slot} has been cancelled during {self.S.status.stage}..')
                    self.S.status.stage = f'CANCELLED_{self.S.status.stage}'
                    self.S.status.force_done = True
                    return None
        if request is not None:
            request.Wait()
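
    # Cancellation pattern (derived from the method above): while waiting on a
    # point-to-point request, the pending Ibcast (req_ibcast) acts as an interrupt
    # channel; callers are expected to bail out right after each communication call:
    #
    #     self.wait_with_interrupt(request=req)
    #     if self.S.status.force_done:
    #         return None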

    def check_iteration_estimate(self, comm):
        """
        Routine to compute and check the error/iteration estimate

        Args:
            comm: time-communicator
        """

        # compute the difference between old and new values
        diff_new = 0.0
        L = self.S.levels[0]

        for m in range(1, L.sweep.coll.num_nodes + 1):
            diff_new = max(diff_new, abs(L.uold[m] - L.u[m]))

        # send the diff forward
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=0)

        self.wait_with_interrupt(request=self.req_diff)
        if self.S.status.force_done:
            return None

        if not self.S.status.first:
            prev_diff = np.empty(1, dtype=float)
            req = comm.Irecv((prev_diff, MPI.DOUBLE), source=self.S.prev, tag=999)
            self.wait_with_interrupt(request=req)
            if self.S.status.force_done:
                return None
            self.logger.debug(
                'recv diff: status %s, process %s, time %s, source %s, tag %s, iter %s'
                % (prev_diff, self.S.status.slot, self.S.time, self.S.prev, 999, self.S.status.iter)
            )
            diff_new = max(prev_diff[0], diff_new)

        if not self.S.status.last:
            self.logger.debug(
                'isend diff: status %s, process %s, time %s, target %s, tag %s, iter %s'
                % (diff_new, self.S.status.slot, self.S.time, self.S.next, 999, self.S.status.iter)
            )
            tmp = np.array(diff_new, dtype=float)
            self.req_diff = comm.Issend((tmp, MPI.DOUBLE), dest=self.S.next, tag=999)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=0)

        # store values from the first iteration
        if self.S.status.iter == 1:
            self.S.status.diff_old_loc = diff_new
            self.S.status.diff_first_loc = diff_new
        # compute the iteration estimate
        elif self.S.status.iter > 1:
            Ltilde_loc = min(diff_new / self.S.status.diff_old_loc, 0.9)
            self.S.status.diff_old_loc = diff_new
            alpha = 1 / (1 - Ltilde_loc) * self.S.status.diff_first_loc
            Kest_loc = np.log(self.S.params.errtol / alpha) / np.log(Ltilde_loc) * 1.05  # safety factor!
            self.logger.debug(
                f'LOCAL: {L.time:8.4f}, {self.S.status.iter}: {int(np.ceil(Kest_loc))}, '
                f'{Ltilde_loc:8.6e}, {Kest_loc:8.6e}, '
                f'{Ltilde_loc ** self.S.status.iter * alpha:8.6e}'
            )
            Kest_glob = Kest_loc
            # if the condition is met, send the interrupt
            if np.ceil(Kest_glob) <= self.S.status.iter:
                if self.S.status.last:
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting..')
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
                    # use a C-int buffer to match MPI.INT on the receiving side
                    comm.Ibcast((np.array([1], dtype=np.intc), MPI.INT), root=self.S.status.slot).Wait()
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting done')
                    self.S.status.done = True
                else:
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
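
    # Worked example of the estimate above (illustrative numbers, not from this
    # file): with diff_first = 1e-2, contraction factor Ltilde = 0.1 and
    # errtol = 1e-8, alpha = 1e-2 / 0.9 ~ 1.1e-2 and
    # Kest = log(1e-8 / 1.1e-2) / log(0.1) * 1.05 ~ 6.3, so the interrupt is sent
    # once iter >= 7.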

    def pfasst(self, comm, num_procs):
        """
        Main function including the stages of SDC, MLSDC and PFASST (the "controller")

        For the workflow of this controller, check out one of our PFASST talks or the pySDC paper

        Args:
            comm: communicator
            num_procs (int): number of parallel processes
        """

        stage = self.S.status.stage

        self.logger.debug(stage + ' - process ' + str(self.S.status.slot))

        # wait for interrupt, if the iteration estimator is used
        if self.params.use_iteration_estimator and stage == 'SPREAD' and not self.S.status.last:
            # use a C-int buffer to match MPI.INT (the broadcast sent in check_iteration_estimate)
            done = np.empty(1, dtype=np.intc)
            self.req_ibcast = comm.Ibcast((done, MPI.INT), root=comm.Get_size() - 1)

        # if the interrupt is there, clean up and finish
        if self.params.use_iteration_estimator and not self.S.status.last and self.req_ibcast.Test():
            self.logger.debug(f'{self.S.status.slot} is done..')
            self.S.status.done = True

            if not stage == 'IT_CHECK':
                self.logger.debug(f'Rewinding {self.S.status.slot} after {stage}..')
                self.S.levels[0].u[1:] = self.S.levels[0].uold[1:]

                for hook in self.hooks:
                    hook.post_iteration(step=self.S, level_number=0)

            for req in self.req_send:
                if req is not None and req != MPI.REQUEST_NULL:
                    req.Cancel()
            if self.req_status is not None and self.req_status != MPI.REQUEST_NULL:
                self.req_status.Cancel()
            if self.req_diff is not None and self.req_diff != MPI.REQUEST_NULL:
                self.req_diff.Cancel()

            self.S.status.stage = 'DONE'
            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)

        else:
            # start cycling, if not interrupted
            switcher = {
                'SPREAD': self.spread,
                'PREDICT': self.predict,
                'IT_CHECK': self.it_check,
                'IT_FINE': self.it_fine,
                'IT_DOWN': self.it_down,
                'IT_COARSE': self.it_coarse,
                'IT_UP': self.it_up,
            }

            switcher.get(stage, self.default)(comm, num_procs)
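
    # Stage transitions driven by pfasst() (read off the methods below):
    #
    #     SPREAD -> PREDICT (multi-level) | IT_CHECK (single level)
    #     PREDICT -> IT_CHECK
    #     IT_CHECK -> IT_DOWN | IT_FINE | IT_COARSE | DONE
    #     IT_DOWN -> IT_COARSE -> IT_UP -> IT_FINE -> IT_CHECK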

    def spread(self, comm, num_procs):
        """
        Spreading phase
        """

        # first stage: spread values
        for hook in self.hooks:
            hook.pre_step(step=self.S, level_number=0)

        # call predictor from sweeper
        self.S.levels[0].sweep.predict()

        if self.params.use_iteration_estimator:
            # store previous iterate to compute difference later on
            self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST with predict
            self.S.status.stage = 'PREDICT'
        else:
            self.S.status.stage = 'IT_CHECK'

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_spread_processing(self, self.S, comm=comm)

    def predict(self, comm, num_procs):
        """
        Predictor phase
        """

        for hook in self.hooks:
            hook.pre_predict(step=self.S, level_number=0)

        if self.params.predict_type is None:
            pass

        elif self.params.predict_type == 'fine_only':
            # do a fine sweep only
            self.S.levels[0].sweep.update_nodes()

        # elif self.params.predict_type == 'libpfasst_style':
        #
        #     # restrict to coarsest level
        #     for l in range(1, len(self.S.levels)):
        #         self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.first:
        #         self.logger.debug('recv data predict: process %s, stage %s, time %s, source %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
        #                            self.S.status.iter))
        #         self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #
        #     # do the sweep with new values
        #     self.S.levels[-1].sweep.update_nodes()
        #     self.S.levels[-1].sweep.compute_end_point()
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.last:
        #         self.logger.debug('send data predict: process %s, stage %s, time %s, target %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
        #                            self.S.status.iter))
        #         self.S.levels[-1].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm).Wait()
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=True)
        #
        #     # go back to fine level, sweeping
        #     for l in range(len(self.S.levels) - 1, 0, -1):
        #         # prolong values
        #         self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])
        #         # on middle levels: do sweep as usual
        #         if l - 1 > 0:
        #             self.S.levels[l - 1].sweep.update_nodes()
        #
        #     # end with a fine sweep
        #     self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'pfasst_burnin':
            # restrict to coarsest level
            for l in range(1, len(self.S.levels)):
                self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])

            for p in range(self.S.status.slot + 1):
                if not p == 0:
                    self.recv_full(comm=comm, level=len(self.S.levels) - 1)
                    if self.S.status.force_done:
                        return None

                # do the sweep with new values
                self.S.levels[-1].sweep.update_nodes()
                self.S.levels[-1].sweep.compute_end_point()

                self.send_full(
                    comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=(p == self.S.status.slot)
                )
                if self.S.status.force_done:
                    return None

            # interpolate back to the finest level
            for l in range(len(self.S.levels) - 1, 0, -1):
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            self.recv_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # end this with a fine sweep
            self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'fmg':
            # TODO: implement FMG predictor
            raise NotImplementedError('FMG predictor is not yet implemented')

        else:
            raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)

        for hook in self.hooks:
            hook.post_predict(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'
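
    # Burn-in illustration for predict_type == 'pfasst_burnin' (numbers are
    # illustrative): slot p runs p + 1 coarse sweep/send rounds, receiving before
    # all but the first, so with 4 ranks slot 3 performs 4 coarse sweeps and every
    # slot holds consistent coarse information once the staircase completes.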

    def it_check(self, comm, num_procs):
        """
        Key routine to check for convergence/termination
        """

        # update values to compute the residual
        self.send_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        self.recv_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        # compute the residual
        self.S.levels[0].sweep.compute_residual(stage='IT_CHECK')

        if self.params.use_iteration_estimator:
            # TODO: replace with convergence controller
            self.check_iteration_estimate(comm=comm)

            if self.S.status.force_done:
                return None

        if self.S.status.iter > 0:
            for hook in self.hooks:
                hook.post_iteration(step=self.S, level_number=0)

        # decide if the step is done, needs to be restarted, and other convergence-related things
        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_iteration_processing(self, self.S, comm=comm)
            C.convergence_control(self, self.S, comm=comm)

        # if not ready, keep doing stuff
        if not self.S.status.done:
            # increment iteration count here (and only here)
            self.S.status.iter += 1

            for hook in self.hooks:
                hook.pre_iteration(step=self.S, level_number=0)
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.pre_iteration_processing(self, self.S, comm=comm)

            if self.params.use_iteration_estimator:
                # store previous iterate to compute difference later on
                self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

            if len(self.S.levels) > 1:  # MLSDC or PFASST
                self.S.status.stage = 'IT_DOWN'
            else:
                if num_procs == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                    self.S.status.stage = 'IT_FINE'
                else:
                    self.S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-like)

        else:
            if not self.params.use_iteration_estimator:
                # Need to finish all pending isend requests. These will occur for the first active process, since
                # in the last iteration the wait statement will not be called ("send and forget")
                for req in self.req_send:
                    if req is not None:
                        req.Wait()
                if self.req_status is not None:
                    self.req_status.Wait()
                if self.req_diff is not None:
                    self.req_diff.Wait()
            else:
                for req in self.req_send:
                    if req is not None:
                        req.Cancel()
                if self.req_status is not None:
                    self.req_status.Cancel()
                if self.req_diff is not None:
                    self.req_diff.Cancel()

            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)
            self.S.status.stage = 'DONE'

    def it_fine(self, comm, num_procs):
        """
        Fine sweeps
        """

        nsweeps = self.S.levels[0].params.nsweeps

        self.S.levels[0].status.sweep = 0

        # do fine sweep
        for k in range(nsweeps):
            self.S.levels[0].status.sweep += 1

            # send values forward
            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # recv values from previous
            self.recv_full(comm=comm, level=0, add_to_stats=(k == nsweeps - 1))
            if self.S.status.force_done:
                return None

            for hook in self.hooks:
                hook.pre_sweep(step=self.S, level_number=0)
            self.S.levels[0].sweep.update_nodes()
            self.S.levels[0].sweep.compute_residual(stage='IT_FINE')
            for hook in self.hooks:
                hook.post_sweep(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'

    def it_down(self, comm, num_procs):
        """
        Go down the hierarchy from finest to coarsest level
        """

        self.S.transfer(source=self.S.levels[0], target=self.S.levels[1])

        # sweep and send on middle levels (not on finest, not on coarsest, though)
        for l in range(1, len(self.S.levels) - 1):
            nsweeps = self.S.levels[l].params.nsweeps

            for _ in range(nsweeps):
                self.send_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                self.recv_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                for hook in self.hooks:
                    hook.pre_sweep(step=self.S, level_number=l)
                self.S.levels[l].sweep.update_nodes()
                self.S.levels[l].sweep.compute_residual(stage='IT_DOWN')
                for hook in self.hooks:
                    hook.post_sweep(step=self.S, level_number=l)

            # transfer further down the hierarchy
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])

        # update stage
        self.S.status.stage = 'IT_COARSE'

    def it_coarse(self, comm, num_procs):
        """
        Coarse sweep
        """

        # receive from previous step (if not first)
        self.recv_full(comm=comm, level=len(self.S.levels) - 1)
        if self.S.status.force_done:
            return None

        # do the sweep
        for hook in self.hooks:
            hook.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        assert self.S.levels[-1].params.nsweeps == 1, (
            'ERROR: this controller can only work with one sweep on the coarse level, got %s'
            % self.S.levels[-1].params.nsweeps
        )
        self.S.levels[-1].sweep.update_nodes()
        self.S.levels[-1].sweep.compute_residual(stage='IT_COARSE')
        for hook in self.hooks:
            hook.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        self.S.levels[-1].sweep.compute_end_point()

        # send to next step
        self.send_full(comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=True)
        if self.S.status.force_done:
            return None

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST
            self.S.status.stage = 'IT_UP'
        else:
            self.S.status.stage = 'IT_CHECK'  # MSSDC

    def it_up(self, comm, num_procs):
        """
        Prolong corrections up to finest level (parallel)
        """

        # receive and sweep on middle levels (except for coarsest level)
        for l in range(len(self.S.levels) - 1, 0, -1):
            # prolong values
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            # on middle levels: do sweep as usual
            if l - 1 > 0:
                nsweeps = self.S.levels[l - 1].params.nsweeps

                for k in range(nsweeps):
                    self.send_full(comm=comm, level=l - 1)
                    if self.S.status.force_done:
                        return None

                    self.recv_full(comm=comm, level=l - 1, add_to_stats=(k == nsweeps - 1))
                    if self.S.status.force_done:
                        return None

                    for hook in self.hooks:
                        hook.pre_sweep(step=self.S, level_number=l - 1)
                    self.S.levels[l - 1].sweep.update_nodes()
                    self.S.levels[l - 1].sweep.compute_residual(stage='IT_UP')
                    for hook in self.hooks:
                        hook.post_sweep(step=self.S, level_number=l - 1)

        # update stage
        self.S.status.stage = 'IT_FINE'

    def default(self, comm, num_procs):
        """
        Default routine to catch an unknown stage; the signature matches the dispatch in pfasst()
        """
        raise ControllerError('Weird stage, got %s' % self.S.status.stage)
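
# A hedged run sketch complementing the construction sketch in __init__
# (`my_script.py` and the controller/description objects are hypothetical
# placeholders, not part of this file):
#
#     uend, stats = controller.run(u0=u0, t0=0.0, Tend=1.0)
#
# launched with, e.g., `mpiexec -n 4 python my_script.py`, so that each of the
# four ranks owns one time step per block.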