Coverage for pySDC/implementations/controller_classes/controller_MPI.py: 66%

384 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1import numpy as np 

2from mpi4py import MPI 

3 

4from pySDC.core.controller import Controller 

5from pySDC.core.errors import ControllerError 

6from pySDC.core.step import Step 

7from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting 

8 

9 

class controller_MPI(Controller):
    """
    PFASST controller, running the parallel version of PFASST in blocks (MG-style).

    Every process of the communicator handed to the constructor owns exactly one step,
    which is stored as ``self.S``.
    """

16 

def __init__(self, controller_params, description, comm):
    """
    Set up the MPI-parallel PFASST controller.

    Args:
        controller_params: parameter set for the controller and the step class
        description: all the parameters to set up the rest (levels, problems, transfer, ...)
        comm: MPI communicator

    Raises:
        ControllerError: if more than one process and more than one level are used while some
            level either does not have the right interval boundary as a collocation node or
            performs a collocation update
    """

    # the base class handles the generic part of the setup
    super().__init__(controller_params, description, useMPI=True)

    # every process owns exactly one step
    self.S = Step(description)

    # keep the communicator around for later use
    self.comm = comm
    n_ranks = self.comm.Get_size()
    my_rank = self.comm.Get_rank()

    # tell the step how many processes take part (helpful here and there)
    self.S.status.time_size = n_ranks

    self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=True)]
    for cc_class in self.base_convergence_controllers:
        self.add_convergence_controller(cc_class, description)

    # only the first rank dumps the setup
    if self.params.dump_setup and my_rank == 0:
        self.dump_setup(step=self.S, controller_params=controller_params, description=description)

    n_levels = len(self.S.levels)

    # handles of pending non-blocking communication: status send, one isend per level,
    # the interrupt broadcast and the diff send
    self.req_status = None
    self.req_send = [None] * n_levels
    self.req_ibcast = None
    self.req_diff = None

    if n_ranks > 1 and n_levels > 1:
        for lvl in self.S.levels:
            end_point_is_last_node = lvl.sweep.coll.right_is_node and not lvl.sweep.params.do_coll_update
            if not end_point_is_last_node:
                raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

    if n_levels == 1 and self.params.predict_type is not None:
        self.logger.warning(
            'you have specified a predictor type but only a single level.. predictor will be ignored'
        )

    # collect the controllers in their prescribed order before calling into them
    ordered_controllers = [self.convergence_controllers[i] for i in self.convergence_controller_order]
    for controller in ordered_controllers:
        controller.setup_status_variables(self, comm=comm)

70 

def run(self, u0, t0, Tend):
    """
    Main driver for running the parallel version of SDC, MSSDC, MLSDC and PFASST

    The time interval is processed block by block: each active process calls ``self.pfasst``
    until its step is flagged as done, then all active processes agree on the initial
    conditions for the next block. Processes whose step would start at or after ``Tend``
    become inactive and leave the loop.

    Args:
        u0: initial values
        t0: starting time
        Tend: ending time

    Returns:
        end values on the finest level
        stats object containing statistics for each step, each level and each iteration

    Raises:
        ControllerError: on rank 0 of ``self.comm`` if its own step is not active, i.e. there is
            nothing to compute at all
    """

    # reset stats to prevent double entries from old runs
    for hook in self.hooks:
        hook.reset_stats()

    # setup time initially: my step starts after the steps of all lower ranks
    all_dt = self.comm.allgather(self.S.dt)
    time = t0 + sum(all_dt[: self.comm.rank])

    # a step is active if it starts before Tend (with a small floating point tolerance)
    active = time < Tend - 10 * np.finfo(float).eps
    # the active flag is used as the color of the split
    comm_active = self.comm.Split(active)
    self.S.status.slot = comm_active.rank

    if self.comm.rank == 0 and not active:
        raise ControllerError('Nothing to do, check t0, dt and Tend!')

    # initialize block of steps with u0
    self.restart_block(comm_active.size, time, u0, comm=comm_active)
    uend = u0

    # call post-setup hook
    for hook in self.hooks:
        hook.post_setup(step=None, level_number=None)

    # call pre-run hook
    for hook in self.hooks:
        hook.pre_run(step=self.S, level_number=0)

    comm_active.Barrier()

    # while any process still active...
    while active:
        # advance the stage machine until this step is done
        while not self.S.status.done:
            self.pfasst(comm_active, comm_active.size)

        # determine where to restart
        restarts = comm_active.allgather(self.S.status.restart)

        # communicate time and solution to be used as next initial conditions
        if True in restarts:
            # the first step that asked for a restart provides its initial conditions
            restart_at = np.where(restarts)[0][0]
            uend = self.S.levels[0].u[0].bcast(root=restart_at, comm=comm_active)
            tend = comm_active.bcast(self.S.time, root=restart_at)
            self.logger.info(f'Starting next block with initial conditions from step {restart_at}')

        else:
            # no restart: continue from the end of the last step in the block
            uend = self.S.levels[0].uend.bcast(root=comm_active.size - 1, comm=comm_active)
            tend = comm_active.bcast(self.S.time + self.S.dt, root=comm_active.size - 1)

        # do convergence controller stuff
        if not self.S.status.restart:
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_step_processing(self, self.S, comm=comm_active)

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.prepare_next_block(self, self.S, self.S.status.time_size, tend, Tend, comm=comm_active)

        # set new time
        all_dt = comm_active.allgather(self.S.dt)
        time = tend + sum(all_dt[: self.S.status.slot])

        active = time < Tend - 10 * np.finfo(float).eps

        # check if we need to split the communicator
        # (only if the last step of the current communicator would start at or after Tend)
        if tend + sum(all_dt[: comm_active.size - 1]) >= Tend - 10 * np.finfo(float).eps:
            comm_active_new = comm_active.Split(active)
            comm_active.Free()
            comm_active = comm_active_new

        self.S.status.slot = comm_active.rank

        # initialize block of steps with u0
        if active:
            self.restart_block(comm_active.size, time, uend, comm=comm_active)

    # call post-run hook
    for hook in self.hooks:
        hook.post_run(step=self.S, level_number=0)

    comm_active.Free()

    return uend, self.return_stats()

166 

def restart_block(self, size, time, u0, comm):
    """
    Helper routine to reset/restart the local step for a new block of (active) steps

    Args:
        size: number of active time steps
        time: current time, stored as the time of every level of the local step
        u0: initial value to distribute across the steps
        comm: the communicator, passed on to the convergence controllers

    Returns:
        None, the local step ``self.S`` and the request handles are reset in place
    """

    # store link to previous and next step (periodic in the number of active steps)
    self.S.prev = (self.S.status.slot - 1) % size
    self.S.next = (self.S.status.slot + 1) % size

    # resets step
    self.S.reset_step()
    # determine whether I am the first and/or last in line
    self.S.status.first = self.S.prev == size - 1
    self.S.status.last = self.S.next == 0
    # initialize step with u0
    self.S.init_step(u0)
    # reset some values
    self.S.status.done = False
    self.S.status.iter = 0
    self.S.status.stage = 'SPREAD'
    for l in self.S.levels:
        l.tag = None
    # forget all handles of non-blocking communication from the previous block
    self.req_status = None
    self.req_diff = None
    self.req_ibcast = None
    self.req_send = [None] * len(self.S.levels)
    self.S.status.prev_done = False
    self.S.status.force_done = False

    for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
        C.reset_status_variables(self, comm=comm)

    self.S.status.time_size = size

    for lvl in self.S.levels:
        lvl.status.time = time
        lvl.status.sweep = 1

214 

def recv(self, target, source, tag=None, comm=None):
    """
    Receive the left boundary value of a level and refresh its right-hand side

    Args:
        target: level which will receive the values
        source: level which initiated the send
        tag: identifier to check if this message is really for me
        comm: communicator
    """
    pending = target.u[0].irecv(source=source, tag=tag, comm=comm)
    self.wait_with_interrupt(request=pending)

    # only re-evaluate f on the left interval boundary if the wait was not interrupted
    if not self.S.status.force_done:
        target.f[0] = target.prob.eval_f(target.u[0], target.time)
    return None

231 

def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False):
    """
    Function to perform the send, including bookkeeping and logging

    If the wait on a request is interrupted (``self.S.status.force_done`` is set), the function
    returns early and the ``post_comm`` hooks are not called.

    Args:
        comm: the communicator
        blocking: flag to indicate that we need blocking communication
        level: the level number
        add_to_stats: a flag to end recording data in the hooks (defaults to False)
    """
    for hook in self.hooks:
        hook.pre_comm(step=self.S, level_number=level)

    if not blocking:
        # make sure the previous send on this level has completed before posting a new one
        self.wait_with_interrupt(request=self.req_send[level])
        if self.S.status.force_done:
            return None

    self.S.levels[level].sweep.compute_end_point()

    if not self.S.status.last:
        self.logger.debug(
            'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
            % (
                self.S.status.slot,
                self.S.status.stage,
                self.S.time,
                self.S.next,
                level * 100 + self.S.status.iter,
                self.S.status.iter,
            )
        )
        # the tag encodes both the level and the iteration
        self.req_send[level] = self.S.levels[level].uend.isend(
            dest=self.S.next, tag=level * 100 + self.S.status.iter, comm=comm
        )
        # NOTE(review): nesting of this branch is reconstructed, the original indentation is not
        # recoverable from the source - confirm it belongs inside the `not last` branch
        if blocking:
            self.wait_with_interrupt(request=self.req_send[level])
            if self.S.status.force_done:
                return None

    for hook in self.hooks:
        hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)

274 

def recv_full(self, comm, level=None, add_to_stats=False):
    """
    Receive data from the previous step, wrapped in communication hooks and logging

    Nothing is received on the first step or if the previous step is already done.

    Args:
        comm: the communicator
        level: the level number
        add_to_stats: a flag to end recording data in the hooks (defaults to False)
    """

    for hook in self.hooks:
        hook.pre_comm(step=self.S, level_number=level)

    expects_data = not (self.S.status.first or self.S.status.prev_done)
    if expects_data:
        # the tag encodes both the level and the iteration
        message_tag = level * 100 + self.S.status.iter
        log_fields = (
            self.S.status.slot,
            self.S.status.stage,
            self.S.time,
            self.S.prev,
            message_tag,
            self.S.status.iter,
        )
        self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' % log_fields)
        self.recv(target=self.S.levels[level], source=self.S.prev, tag=message_tag, comm=comm)

    for hook in self.hooks:
        hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)

303 

def wait_with_interrupt(self, request):
    """
    Wait for a non-blocking communication to complete, unless the interrupt broadcast arrives first

    If an interrupt broadcast is pending (``self.req_ibcast``), the request is polled and the
    wait is abandoned as soon as the broadcast has completed. In that case the stage is prefixed
    with ``CANCELLED_`` and ``force_done`` is set on the step status.

    Args:
        request: request to wait for (nothing happens if it is None)
    """
    if request is None:
        return None

    if self.req_ibcast is not None:
        # poll the request while watching for the interrupt
        while not request.Test():
            if self.req_ibcast.Test():
                self.logger.debug(f'{self.S.status.slot} has been cancelled during {self.S.status.stage}..')
                self.S.status.stage = f'CANCELLED_{self.S.status.stage}'
                self.S.status.force_done = True
                return None

    request.Wait()

320 

def check_iteration_estimate(self, comm):
    """
    Routine to compute and check error/iteration estimation

    The largest difference between the previous and the current iterate is computed on the
    finest level and passed forward from step to step, so that each step works with the maximum
    over itself and all previous steps. From the second iteration on, an estimate of the number
    of iterations needed to reach ``self.S.params.errtol`` is computed. Once the estimate is
    reached on the last step, that step broadcasts the interrupt and flags itself as done.

    Args:
        comm: time-communicator
    """

    # Compute diff between old and new values
    diff_new = 0.0
    L = self.S.levels[0]

    for m in range(1, L.sweep.coll.num_nodes + 1):
        diff_new = max(diff_new, abs(L.uold[m] - L.u[m]))

    # Send forward diff
    for hook in self.hooks:
        hook.pre_comm(step=self.S, level_number=0)

    # make sure the diff sent in the previous call has been delivered
    self.wait_with_interrupt(request=self.req_diff)
    if self.S.status.force_done:
        return None

    if not self.S.status.first:
        prev_diff = np.empty(1, dtype=float)
        req = comm.Irecv((prev_diff, MPI.DOUBLE), source=self.S.prev, tag=999)
        self.wait_with_interrupt(request=req)
        if self.S.status.force_done:
            return None
        self.logger.debug(
            'recv diff: status %s, process %s, time %s, source %s, tag %s, iter %s'
            % (prev_diff, self.S.status.slot, self.S.time, self.S.prev, 999, self.S.status.iter)
        )
        diff_new = max(prev_diff[0], diff_new)

    if not self.S.status.last:
        self.logger.debug(
            'isend diff: status %s, process %s, time %s, target %s, tag %s, iter %s'
            % (diff_new, self.S.status.slot, self.S.time, self.S.next, 999, self.S.status.iter)
        )
        tmp = np.array(diff_new, dtype=float)
        self.req_diff = comm.Issend((tmp, MPI.DOUBLE), dest=self.S.next, tag=999)

    for hook in self.hooks:
        hook.post_comm(step=self.S, level_number=0)

    # Store values from first iteration
    if self.S.status.iter == 1:
        self.S.status.diff_old_loc = diff_new
        self.S.status.diff_first_loc = diff_new
    # Compute iteration estimate
    elif self.S.status.iter > 1:
        # contraction factor estimated from two consecutive diffs, capped at 0.9
        # NOTE(review): diff_old_loc == 0 is not guarded against here - confirm this cannot occur
        Ltilde_loc = min(diff_new / self.S.status.diff_old_loc, 0.9)
        self.S.status.diff_old_loc = diff_new
        alpha = 1 / (1 - Ltilde_loc) * self.S.status.diff_first_loc
        Kest_loc = np.log(self.S.params.errtol / alpha) / np.log(Ltilde_loc) * 1.05  # Safety factor!
        self.logger.debug(
            f'LOCAL: {L.time:8.4f}, {self.S.status.iter}: {int(np.ceil(Kest_loc))}, '
            f'{Ltilde_loc:8.6e}, {Kest_loc:8.6e}, '
            f'{Ltilde_loc ** self.S.status.iter * alpha:8.6e}'
        )
        Kest_glob = Kest_loc
        # If condition is met, send interrupt
        if np.ceil(Kest_glob) <= self.S.status.iter:
            # NOTE(review): the nesting of the if/else below is reconstructed, the original
            # indentation is not recoverable from the source - confirm against the repository
            if self.S.status.last:
                self.logger.debug(f'{self.S.status.slot} is done, broadcasting..')
                for hook in self.hooks:
                    hook.pre_comm(step=self.S, level_number=0)
                comm.Ibcast((np.array([1]), MPI.INT), root=self.S.status.slot).Wait()
                for hook in self.hooks:
                    hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
                self.logger.debug(f'{self.S.status.slot} is done, broadcasting done')
                self.S.status.done = True
            else:
                # no communication here, only the communication hooks are called
                for hook in self.hooks:
                    hook.pre_comm(step=self.S, level_number=0)
                for hook in self.hooks:
                    hook.post_comm(step=self.S, level_number=0, add_to_stats=True)

399 

def pfasst(self, comm, num_procs):
    """
    Main function including the stages of SDC, MLSDC and PFASST (the "controller")

    Each call executes exactly one stage, selected by ``self.S.status.stage``, unless the
    interrupt broadcast of the iteration estimator has arrived, in which case the step is
    finished and its pending requests are cancelled.

    For the workflow of this controller, check out one of our PFASST talks or the pySDC paper

    Args:
        comm: communicator
        num_procs (int): number of parallel processes
    """

    stage = self.S.status.stage

    self.logger.debug(stage + ' - process ' + str(self.S.status.slot))

    # Wait for interrupt, if iteration estimator is used
    if self.params.use_iteration_estimator and stage == 'SPREAD' and not self.S.status.last:
        # NOTE(review): np.empty(1) is float64 by default while the buffer is declared as
        # MPI.INT - confirm this matches what the broadcasting process sends
        done = np.empty(1)
        self.req_ibcast = comm.Ibcast((done, MPI.INT), root=comm.Get_size() - 1)

    # If interrupt is there, cleanup and finish
    if self.params.use_iteration_estimator and not self.S.status.last and self.req_ibcast.Test():
        self.logger.debug(f'{self.S.status.slot} is done..')
        self.S.status.done = True

        if not stage == 'IT_CHECK':
            # restore the values stored as previous iterate
            self.logger.debug(f'Rewinding {self.S.status.slot} after {stage}..')
            self.S.levels[0].u[1:] = self.S.levels[0].uold[1:]

        # NOTE(review): nesting of the remaining statements of this branch is reconstructed,
        # the original indentation is not recoverable from the source
        for hook in self.hooks:
            hook.post_iteration(step=self.S, level_number=0)

        # cancel whatever is still pending
        for req in self.req_send:
            if req is not None and req != MPI.REQUEST_NULL:
                req.Cancel()
        if self.req_status is not None and self.req_status != MPI.REQUEST_NULL:
            self.req_status.Cancel()
        if self.req_diff is not None and self.req_diff != MPI.REQUEST_NULL:
            self.req_diff.Cancel()

        self.S.status.stage = 'DONE'
        for hook in self.hooks:
            hook.post_step(step=self.S, level_number=0)

    else:
        # Start cycling, if not interrupted
        switcher = {
            'SPREAD': self.spread,
            'PREDICT': self.predict,
            'IT_CHECK': self.it_check,
            'IT_FINE': self.it_fine,
            'IT_DOWN': self.it_down,
            'IT_COARSE': self.it_coarse,
            'IT_UP': self.it_up,
        }

        # every stage function, including the fallback, is called with (comm, num_procs)
        switcher.get(stage, self.default)(comm, num_procs)

457 

def spread(self, comm, num_procs):
    """
    Spreading phase: call the sweeper's predictor on the finest level and select the next stage

    Args:
        comm: communicator
        num_procs (int): number of parallel processes
    """
    finest = self.S.levels[0]

    # first stage: spread values
    for hook in self.hooks:
        hook.pre_step(step=self.S, level_number=0)

    # call predictor from sweeper
    finest.sweep.predict()

    # the iteration estimator needs the previous iterate to compute differences later on
    if self.params.use_iteration_estimator:
        finest.uold[1:] = finest.u[1:]

    # more than one level (MLSDC or PFASST) goes through the predictor first
    self.S.status.stage = 'PREDICT' if len(self.S.levels) > 1 else 'IT_CHECK'

    ordered_controllers = [self.convergence_controllers[i] for i in self.convergence_controller_order]
    for controller in ordered_controllers:
        controller.post_spread_processing(self, self.S, comm=comm)

482 

def predict(self, comm, num_procs):
    """
    Predictor phase

    Supported values of ``self.params.predict_type`` are None (do nothing), 'fine_only' and
    'pfasst_burnin'. A 'libpfasst_style' predictor is not available in this controller and is
    rejected like any other unknown type.

    Args:
        comm: communicator
        num_procs (int): number of parallel processes

    Raises:
        NotImplementedError: for the 'fmg' predictor
        ControllerError: for any other unknown predictor type
    """

    for hook in self.hooks:
        hook.pre_predict(step=self.S, level_number=0)

    predictor = self.params.predict_type

    if predictor is None:
        pass

    elif predictor == 'fine_only':
        # a single sweep on the finest level is all there is to do
        self.S.levels[0].sweep.update_nodes()

    elif predictor == 'pfasst_burnin':
        coarsest_idx = len(self.S.levels) - 1

        # restrict level by level down to the coarsest one
        for idx in range(1, len(self.S.levels)):
            self.S.transfer(source=self.S.levels[idx - 1], target=self.S.levels[idx])

        # burn-in: the step in slot p performs p + 1 coarse sweeps
        for burn in range(self.S.status.slot + 1):
            if burn != 0:
                self.recv_full(comm=comm, level=coarsest_idx)
                if self.S.status.force_done:
                    return None

            # sweep with the new values
            self.S.levels[-1].sweep.update_nodes()
            self.S.levels[-1].sweep.compute_end_point()

            is_final_burn = burn == self.S.status.slot
            self.send_full(comm=comm, blocking=True, level=coarsest_idx, add_to_stats=is_final_burn)
            if self.S.status.force_done:
                return None

        # interpolate level by level back up to the finest one
        for idx in range(coarsest_idx, 0, -1):
            self.S.transfer(source=self.S.levels[idx], target=self.S.levels[idx - 1])

        self.send_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        self.recv_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        # finish with a sweep on the finest level
        self.S.levels[0].sweep.update_nodes()

    elif predictor == 'fmg':
        # TODO: implement FMG predictor
        raise NotImplementedError('FMG predictor is not yet implemented')

    else:
        raise ControllerError('Wrong predictor type, got %s' % predictor)

    for hook in self.hooks:
        hook.post_predict(step=self.S, level_number=0)

    # the iteration starts with a convergence check
    self.S.status.stage = 'IT_CHECK'

583 

def it_check(self, comm, num_procs):
    """
    Key routine to check for convergence/termination

    Exchanges the current values on the finest level, computes the residual and lets the
    convergence controllers decide whether the step is done. If not, the iteration counter is
    incremented and the next stage is selected, otherwise pending requests are finished (or
    cancelled if the iteration estimator is used) and the stage is set to 'DONE'.

    Args:
        comm: communicator
        num_procs (int): number of parallel processes
    """

    # Update values to compute the residual
    self.send_full(comm=comm, level=0)
    if self.S.status.force_done:
        return None

    self.recv_full(comm=comm, level=0)
    if self.S.status.force_done:
        return None

    # compute the residual
    self.S.levels[0].sweep.compute_residual(stage='IT_CHECK')

    if self.params.use_iteration_estimator:
        # TODO: replace with convergence controller
        self.check_iteration_estimate(comm=comm)

    if self.S.status.force_done:
        return None

    # the very first check happens before any iteration has been performed
    if self.S.status.iter > 0:
        for hook in self.hooks:
            hook.post_iteration(step=self.S, level_number=0)

    # decide if the step is done, needs to be restarted and other things convergence related
    for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
        C.post_iteration_processing(self, self.S, comm=comm)
        C.convergence_control(self, self.S, comm=comm)

    # if not ready, keep doing stuff
    if not self.S.status.done:
        # increment iteration count here (and only here)
        self.S.status.iter += 1

        for hook in self.hooks:
            hook.pre_iteration(step=self.S, level_number=0)
        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.pre_iteration_processing(self, self.S, comm=comm)

        if self.params.use_iteration_estimator:
            # store previous iterate to compute difference later on
            self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

        if len(self.S.levels) > 1:  # MLSDC or PFASST
            self.S.status.stage = 'IT_DOWN'
        else:
            if num_procs == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                self.S.status.stage = 'IT_FINE'
            else:
                self.S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-like)

    else:
        if not self.params.use_iteration_estimator:
            # Need to finish all pending isend requests. These will occur for the first active process, since
            # in the last iteration the wait statement will not be called ("send and forget")
            for req in self.req_send:
                if req is not None:
                    req.Wait()
            if self.req_status is not None:
                self.req_status.Wait()
            if self.req_diff is not None:
                self.req_diff.Wait()
        else:
            # with the iteration estimator, pending requests are cancelled instead of waited for
            for req in self.req_send:
                if req is not None:
                    req.Cancel()
            if self.req_status is not None:
                self.req_status.Cancel()
            if self.req_diff is not None:
                self.req_diff.Cancel()

        for hook in self.hooks:
            hook.post_step(step=self.S, level_number=0)
        self.S.status.stage = 'DONE'

662 

def it_fine(self, comm, num_procs):
    """
    Fine sweeps: send, receive and sweep on the finest level, then go back to the check

    Args:
        comm: communicator
        num_procs (int): number of parallel processes
    """
    total_sweeps = self.S.levels[0].params.nsweeps

    # the sweep counter of the level is counted up from zero
    self.S.levels[0].status.sweep = 0

    for sweep_idx in range(total_sweeps):
        self.S.levels[0].status.sweep += 1

        # pass the current values on to the next step
        self.send_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        # fetch new values from the previous step, stats are only recorded for the last sweep
        is_last_sweep = sweep_idx == total_sweeps - 1
        self.recv_full(comm=comm, level=0, add_to_stats=is_last_sweep)
        if self.S.status.force_done:
            return None

        for hook in self.hooks:
            hook.pre_sweep(step=self.S, level_number=0)
        fine_sweeper = self.S.levels[0].sweep
        fine_sweeper.update_nodes()
        fine_sweeper.compute_residual(stage='IT_FINE')
        for hook in self.hooks:
            hook.post_sweep(step=self.S, level_number=0)

    # next up: check for convergence
    self.S.status.stage = 'IT_CHECK'

695 

def it_down(self, comm, num_procs):
    """
    Go down the hierarchy from finest to coarsest level

    Args:
        comm: communicator
        num_procs (int): number of parallel processes
    """

    # restrict from the finest level first
    self.S.transfer(source=self.S.levels[0], target=self.S.levels[1])

    # the middle levels (neither finest nor coarsest) send, receive and sweep
    for idx in range(1, len(self.S.levels) - 1):
        sweeps_here = self.S.levels[idx].params.nsweeps

        for _ in range(sweeps_here):
            self.send_full(comm=comm, level=idx)
            if self.S.status.force_done:
                return None

            self.recv_full(comm=comm, level=idx)
            if self.S.status.force_done:
                return None

            for hook in self.hooks:
                hook.pre_sweep(step=self.S, level_number=idx)
            mid_sweeper = self.S.levels[idx].sweep
            mid_sweeper.update_nodes()
            mid_sweeper.compute_residual(stage='IT_DOWN')
            for hook in self.hooks:
                hook.post_sweep(step=self.S, level_number=idx)

        # restrict to the next coarser level
        self.S.transfer(source=self.S.levels[idx], target=self.S.levels[idx + 1])

    # next up: the coarse sweep
    self.S.status.stage = 'IT_COARSE'

728 

def it_coarse(self, comm, num_procs):
    """
    Coarse sweep: receive from the previous step, sweep on the coarsest level, send to the next

    Args:
        comm: communicator
        num_procs (int): number of parallel processes
    """
    coarsest_idx = len(self.S.levels) - 1

    # get new values from the previous step (nothing happens on the first one)
    self.recv_full(comm=comm, level=coarsest_idx)
    if self.S.status.force_done:
        return None

    coarsest = self.S.levels[-1]

    # sweep on the coarsest level
    for hook in self.hooks:
        hook.pre_sweep(step=self.S, level_number=coarsest_idx)
    assert coarsest.params.nsweeps == 1, (
        'ERROR: this controller can only work with one sweep on the coarse level, got %s'
        % coarsest.params.nsweeps
    )
    coarsest.sweep.update_nodes()
    coarsest.sweep.compute_residual(stage='IT_COARSE')
    for hook in self.hooks:
        hook.post_sweep(step=self.S, level_number=coarsest_idx)
    coarsest.sweep.compute_end_point()

    # pass the result on to the next step, waiting for completion
    self.send_full(comm=comm, blocking=True, level=coarsest_idx, add_to_stats=True)
    if self.S.status.force_done:
        return None

    # with more than one level (MLSDC or PFASST) go back up, otherwise (MSSDC) check directly
    self.S.status.stage = 'IT_UP' if len(self.S.levels) > 1 else 'IT_CHECK'

762 

def it_up(self, comm, num_procs):
    """
    Prolong corrections up to finest level (parallel)

    Args:
        comm: communicator
        num_procs (int): number of parallel processes
    """

    # walk from the coarsest level upwards
    for coarse_idx in range(len(self.S.levels) - 1, 0, -1):
        fine_idx = coarse_idx - 1

        # prolong values to the next finer level
        self.S.transfer(source=self.S.levels[coarse_idx], target=self.S.levels[fine_idx])

        # no sweeps on the finest level here, those are left to the fine stage
        if not fine_idx > 0:
            continue

        sweeps_here = self.S.levels[fine_idx].params.nsweeps

        for sweep_idx in range(sweeps_here):
            self.send_full(comm, level=fine_idx)
            if self.S.status.force_done:
                return None

            # stats are only recorded for the last sweep
            is_last_sweep = sweep_idx == sweeps_here - 1
            self.recv_full(comm=comm, level=fine_idx, add_to_stats=is_last_sweep)
            if self.S.status.force_done:
                return None

            for hook in self.hooks:
                hook.pre_sweep(step=self.S, level_number=fine_idx)
            mid_sweeper = self.S.levels[fine_idx].sweep
            mid_sweeper.update_nodes()
            mid_sweeper.compute_residual(stage='IT_UP')
            for hook in self.hooks:
                hook.post_sweep(step=self.S, level_number=fine_idx)

    # next up: the fine sweep
    self.S.status.stage = 'IT_FINE'

795 

def default(self, comm=None, num_procs=None):
    """
    Default routine to catch wrong status

    The stage dispatcher calls every stage routine, including this fallback, with both the
    communicator and the number of processes. Both arguments are therefore accepted (and
    ignored), so that an unknown stage results in the intended ControllerError instead of a
    TypeError caused by a mismatching signature.

    Args:
        comm: communicator (unused)
        num_procs (int): number of parallel processes (unused)

    Raises:
        ControllerError: always, reporting the unknown stage
    """
    raise ControllerError('Weird stage, got %s' % self.S.status.stage)