Coverage for pySDC/implementations/controller_classes/controller_MPI.py: 66%
386 statements

import numpy as np
from mpi4py import MPI

from pySDC.core.controller import Controller
from pySDC.core.errors import ControllerError
from pySDC.core.step import Step
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting


class controller_MPI(Controller):
    """
    PFASST controller, running parallel version of PFASST in blocks (MG-style)
    """

    def __init__(self, controller_params, description, comm):
        """
        Initialization routine for PFASST controller

        Args:
            controller_params: parameter set for the controller and the step class
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
            comm: MPI communicator
        """
        # call parent's initialization routine
        super().__init__(controller_params, description, useMPI=True)

        # create a single step per processor
        self.S = Step(description)

        # store the communicator for future use
        self.comm = comm

        num_procs = self.comm.Get_size()
        rank = self.comm.Get_rank()

        # insert data on time communicator to the steps (helpful here and there)
        self.S.status.time_size = num_procs

        self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=True)]
        for convergence_controller in self.base_convergence_controllers:
            self.add_convergence_controller(convergence_controller, description)

        if self.params.dump_setup and rank == 0:
            self.dump_setup(step=self.S, controller_params=controller_params, description=description)

        num_levels = len(self.S.levels)

        # request handle for status send
        self.req_status = None
        # request handle container for isend
        self.req_send = [None] * num_levels
        self.req_ibcast = None
        self.req_diff = None

        if num_procs > 1 and num_levels > 1:
            for L in self.S.levels:
                if not L.sweep.coll.right_is_node or L.sweep.params.do_coll_update:
                    raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

        if num_levels == 1 and self.params.predict_type is not None:
            self.logger.warning(
                'you have specified a predictor type but only a single level; the predictor will be ignored'
            )

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.setup_status_variables(self, comm=comm)

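    # A minimal usage sketch (hypothetical, for illustration only: the
    # `description` and `controller_params` contents below are placeholders,
    # not defaults shipped with pySDC; a real run needs a complete
    # problem/sweeper/level description):
    #
    #     from mpi4py import MPI
    #     controller = controller_MPI(
    #         controller_params={'logger_level': 30},
    #         description=description,  # problem, sweeper, level params, ...
    #         comm=MPI.COMM_WORLD,
    #     )
    #     uend, stats = controller.run(u0=u0, t0=0.0, Tend=1.0)
    #
    # Each rank owns exactly one Step, so the script is launched with as many
    # processes as parallel time steps, e.g. `mpirun -np 4 python run.py`.
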
    def run(self, u0, t0, Tend):
        """
        Main driver for running the parallel version of SDC, MSSDC, MLSDC and PFASST

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the finest level
            stats object containing statistics for each step, each level and each iteration
        """

        # reset stats to prevent double entries from old runs
        for hook in self.hooks:
            hook.reset_stats()

        # set up time initially
        all_dt = self.comm.allgather(self.S.dt)
        time = t0 + sum(all_dt[: self.comm.rank])

        active = time < Tend - 10 * np.finfo(float).eps
        comm_active = self.comm.Split(active)
        self.S.status.slot = comm_active.rank

        if self.comm.rank == 0 and not active:
            raise ControllerError('Nothing to do, check t0, dt and Tend!')

        # initialize block of steps with u0
        self.restart_block(comm_active.size, time, u0, comm=comm_active)
        uend = u0

        # call post-setup hook
        for hook in self.hooks:
            hook.post_setup(step=None, level_number=None)

        # call pre-run hook
        for hook in self.hooks:
            hook.pre_run(step=self.S, level_number=0)

        comm_active.Barrier()

        # while any process is still active...
        while active:
            while not self.S.status.done:
                self.pfasst(comm_active, comm_active.size)

            # determine where to restart
            restarts = comm_active.allgather(self.S.status.restart)

            # communicate time and solution to be used as next initial conditions
            if True in restarts:
                restart_at = np.where(restarts)[0][0]
                uend = self.S.levels[0].u[0].bcast(root=restart_at, comm=comm_active)
                tend = comm_active.bcast(self.S.time, root=restart_at)
                self.logger.info(f'Starting next block with initial conditions from step {restart_at}')
            else:
                uend = self.S.levels[0].uend.bcast(root=comm_active.size - 1, comm=comm_active)
                tend = comm_active.bcast(self.S.time + self.S.dt, root=comm_active.size - 1)

            # do convergence controller bookkeeping
            if not self.S.status.restart:
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.post_step_processing(self, self.S, comm=comm_active)

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.prepare_next_block(self, self.S, self.S.status.time_size, tend, Tend, comm=comm_active)

            # set new time
            all_dt = comm_active.allgather(self.S.dt)
            time = tend + sum(all_dt[: self.S.status.slot])

            active = time < Tend - 10 * np.finfo(float).eps

            # check if we need to split the communicator
            if tend + sum(all_dt[: comm_active.size - 1]) >= Tend - 10 * np.finfo(float).eps:
                comm_active_new = comm_active.Split(active)
                comm_active.Free()
                comm_active = comm_active_new

            self.S.status.slot = comm_active.rank

            # initialize the next block of steps with the previous end values
            if active:
                self.restart_block(comm_active.size, time, uend, comm=comm_active)

        # call post-run hook
        for hook in self.hooks:
            hook.post_run(step=self.S, level_number=0)

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_run_processing(self, self.S, comm=self.comm)

        comm_active.Free()

        return uend, self.return_stats()

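    # Worked example for the block layout in run() (illustrative numbers):
    # with t0 = 0, uniform dt = 0.1 and 4 ranks, allgather yields
    # all_dt = [0.1, 0.1, 0.1, 0.1], so rank r starts at
    # time = t0 + sum(all_dt[:r]) = 0.0, 0.1, 0.2, 0.3 and one block covers
    # [0, 0.4). After each block, the start times shift forward by the block
    # width and ranks that would start beyond Tend are split off the active
    # communicator.
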
    def restart_block(self, size, time, u0, comm):
        """
        Helper routine to reset/restart block of (active) steps

        Args:
            size: number of active time steps
            time: current time
            u0: initial value to distribute across the steps
            comm: the communicator
        """

        # store links to the previous and next steps
        self.S.prev = (self.S.status.slot - 1) % size
        self.S.next = (self.S.status.slot + 1) % size

        # reset step
        self.S.reset_step()
        # determine whether I am the first and/or last in line
        self.S.status.first = self.S.prev == size - 1
        self.S.status.last = self.S.next == 0
        # initialize step with u0
        self.S.init_step(u0)
        # reset some values
        self.S.status.done = False
        self.S.status.iter = 0
        self.S.status.stage = 'SPREAD'
        for l in self.S.levels:
            l.tag = None
        self.req_status = None
        self.req_ibcast = None
        self.req_diff = None
        self.req_send = [None] * len(self.S.levels)
        self.S.status.prev_done = False
        self.S.status.force_done = False

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_status_variables(self, comm=comm)

        self.S.status.time_size = size

        for lvl in self.S.levels:
            lvl.status.time = time
            lvl.status.sweep = 1

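    # The neighbor bookkeeping above is plain modular arithmetic: with
    # size = 4, slot 0 gets prev = (0 - 1) % 4 = 3 and next = 1, so being
    # first is detected via prev == size - 1 and being last via next == 0,
    # without storing extra boundary flags.
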
    def recv(self, target, source, tag=None, comm=None):
        """
        Receive function

        Args:
            target: level which will receive the values
            source: rank of the process which initiated the send
            tag: identifier to check if this message is really for me
            comm: communicator
        """
        req = target.u[0].irecv(source=source, tag=tag, comm=comm)
        self.wait_with_interrupt(request=req)
        if self.S.status.force_done:
            return None
        # re-evaluate f on left interval boundary
        target.f[0] = target.prob.eval_f(target.u[0], target.time)

    def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False):
        """
        Function to perform the send, including bookkeeping and logging

        Args:
            comm: the communicator
            blocking: flag to indicate that we need blocking communication
            level: the level number
            add_to_stats: flag to end recording of communication data in the hooks (defaults to False)
        """
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)

        if not blocking:
            self.wait_with_interrupt(request=self.req_send[level])
            if self.S.status.force_done:
                return None

        self.S.levels[level].sweep.compute_end_point()

        if not self.S.status.last:
            self.logger.debug(
                'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.next,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
            self.req_send[level] = self.S.levels[level].uend.isend(
                dest=self.S.next, tag=level * 100 + self.S.status.iter, comm=comm
            )
            if blocking:
                self.wait_with_interrupt(request=self.req_send[level])
                if self.S.status.force_done:
                    return None

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)

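    # The tag used above encodes level and iteration: tag = level * 100 + iter,
    # e.g. level 1 at iteration 3 sends with tag 103. A receive matching on the
    # same formula therefore cannot consume a message from a different
    # level/iteration pair (assuming fewer than 100 iterations per step).
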
    def recv_full(self, comm, level=None, add_to_stats=False):
        """
        Function to perform the recv, including bookkeeping and logging

        Args:
            comm: the communicator
            level: the level number
            add_to_stats: flag to end recording of communication data in the hooks (defaults to False)
        """

        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)
        if not self.S.status.first and not self.S.status.prev_done:
            self.logger.debug(
                'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.prev,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
            self.recv(target=self.S.levels[level], source=self.S.prev, tag=level * 100 + self.S.status.iter, comm=comm)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)

    def wait_with_interrupt(self, request):
        """
        Wrapper for waiting for the completion of a non-blocking communication; can be interrupted

        Args:
            request: request to wait for
        """
        if request is not None and self.req_ibcast is not None:
            while not request.Test():
                if self.req_ibcast.Test():
                    self.logger.debug(f'{self.S.status.slot} has been cancelled during {self.S.status.stage}..')
                    self.S.status.stage = f'CANCELLED_{self.S.status.stage}'
                    self.S.status.force_done = True
                    return None
        if request is not None:
            request.Wait()

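    # The interrupt mechanism above is plain mpi4py polling: instead of
    # request.Wait(), spin on request.Test() and bail out once a second
    # request (the convergence broadcast) completes. A minimal standalone
    # sketch of the same pattern, assuming pending requests req_work and
    # req_stop:
    #
    #     while not req_work.Test():
    #         if req_stop.Test():
    #             cancelled = True  # caller must check this flag and unwind
    #             break
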
    def check_iteration_estimate(self, comm):
        """
        Routine to compute and check error/iteration estimation

        Args:
            comm: time-communicator
        """

        # compute diff between old and new values
        diff_new = 0.0
        L = self.S.levels[0]

        for m in range(1, L.sweep.coll.num_nodes + 1):
            diff_new = max(diff_new, abs(L.uold[m] - L.u[m]))

        # send forward diff
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=0)

        self.wait_with_interrupt(request=self.req_diff)
        if self.S.status.force_done:
            return None

        if not self.S.status.first:
            prev_diff = np.empty(1, dtype=float)
            req = comm.Irecv((prev_diff, MPI.DOUBLE), source=self.S.prev, tag=999)
            self.wait_with_interrupt(request=req)
            if self.S.status.force_done:
                return None
            self.logger.debug(
                'recv diff: status %s, process %s, time %s, source %s, tag %s, iter %s'
                % (prev_diff, self.S.status.slot, self.S.time, self.S.prev, 999, self.S.status.iter)
            )
            diff_new = max(prev_diff[0], diff_new)

        if not self.S.status.last:
            self.logger.debug(
                'isend diff: status %s, process %s, time %s, target %s, tag %s, iter %s'
                % (diff_new, self.S.status.slot, self.S.time, self.S.next, 999, self.S.status.iter)
            )
            tmp = np.array(diff_new, dtype=float)
            self.req_diff = comm.Issend((tmp, MPI.DOUBLE), dest=self.S.next, tag=999)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=0)

        # store values from first iteration
        if self.S.status.iter == 1:
            self.S.status.diff_old_loc = diff_new
            self.S.status.diff_first_loc = diff_new
        # compute iteration estimate
        elif self.S.status.iter > 1:
            Ltilde_loc = min(diff_new / self.S.status.diff_old_loc, 0.9)
            self.S.status.diff_old_loc = diff_new
            alpha = 1 / (1 - Ltilde_loc) * self.S.status.diff_first_loc
            Kest_loc = np.log(self.S.params.errtol / alpha) / np.log(Ltilde_loc) * 1.05  # Safety factor!
            self.logger.debug(
                f'LOCAL: {L.time:8.4f}, {self.S.status.iter}: {int(np.ceil(Kest_loc))}, '
                f'{Ltilde_loc:8.6e}, {Kest_loc:8.6e}, '
                f'{Ltilde_loc ** self.S.status.iter * alpha:8.6e}'
            )
            Kest_glob = Kest_loc
            # if the condition is met, send interrupt
            if np.ceil(Kest_glob) <= self.S.status.iter:
                if self.S.status.last:
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting..')
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
                    # use a C int buffer to match MPI.INT
                    comm.Ibcast((np.array([1], dtype=np.intc), MPI.INT), root=self.S.status.slot).Wait()
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting done')
                    self.S.status.done = True
                else:
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)

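    # Worked example for the estimate above (illustrative numbers): with a
    # first increment diff_first_loc = 1e-2 and a contraction factor that
    # settles at Ltilde = 0.1, alpha = 1 / (1 - 0.1) * 1e-2 ~ 1.11e-2; for
    # errtol = 1e-10 this gives Kest = log(1e-10 / 1.11e-2) / log(0.1) * 1.05
    # ~ 8.4, so the interrupt fires once iter >= ceil(8.4) = 9.
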
    def pfasst(self, comm, num_procs):
        """
        Main function including the stages of SDC, MLSDC and PFASST (the "controller")

        For the workflow of this controller, check out one of our PFASST talks or the pySDC paper

        Args:
            comm: communicator
            num_procs (int): number of parallel processes
        """

        stage = self.S.status.stage

        self.logger.debug(stage + ' - process ' + str(self.S.status.slot))

        # wait for interrupt, if the iteration estimator is used
        if self.params.use_iteration_estimator and stage == 'SPREAD' and not self.S.status.last:
            done = np.empty(1, dtype=np.intc)  # C int buffer to match MPI.INT
            self.req_ibcast = comm.Ibcast((done, MPI.INT), root=comm.Get_size() - 1)

        # if the interrupt has arrived, clean up and finish
        if self.params.use_iteration_estimator and not self.S.status.last and self.req_ibcast.Test():
            self.logger.debug(f'{self.S.status.slot} is done..')
            self.S.status.done = True

            if stage != 'IT_CHECK':
                self.logger.debug(f'Rewinding {self.S.status.slot} after {stage}..')
                self.S.levels[0].u[1:] = self.S.levels[0].uold[1:]

            for hook in self.hooks:
                hook.post_iteration(step=self.S, level_number=0)

            for req in self.req_send:
                if req is not None and req != MPI.REQUEST_NULL:
                    req.Cancel()
            if self.req_status is not None and self.req_status != MPI.REQUEST_NULL:
                self.req_status.Cancel()
            if self.req_diff is not None and self.req_diff != MPI.REQUEST_NULL:
                self.req_diff.Cancel()

            self.S.status.stage = 'DONE'
            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)
        else:
            # start cycling, if not interrupted
            switcher = {
                'SPREAD': self.spread,
                'PREDICT': self.predict,
                'IT_CHECK': self.it_check,
                'IT_FINE': self.it_fine,
                'IT_DOWN': self.it_down,
                'IT_COARSE': self.it_coarse,
                'IT_UP': self.it_up,
            }

            switcher.get(stage, self.default)(comm, num_procs)

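    # Stage machine driven by the switcher above, one call per stage:
    #   SPREAD -> (PREDICT ->) IT_CHECK -> IT_DOWN -> IT_COARSE -> IT_UP
    #   -> IT_FINE -> IT_CHECK -> ... -> DONE
    # Single-level runs skip IT_DOWN/IT_UP: IT_CHECK moves straight to IT_FINE
    # (SDC / Jacobi-style MSSDC) or to IT_COARSE (Gauss-Seidel-style MSSDC).
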
    def spread(self, comm, num_procs):
        """
        Spreading phase
        """

        # first stage: spread values
        for hook in self.hooks:
            hook.pre_step(step=self.S, level_number=0)

        # call predictor from sweeper
        self.S.levels[0].sweep.predict()

        if self.params.use_iteration_estimator:
            # store previous iterate to compute difference later on
            self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST with predict
            self.S.status.stage = 'PREDICT'
        else:
            self.S.status.stage = 'IT_CHECK'

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_spread_processing(self, self.S, comm=comm)

    def predict(self, comm, num_procs):
        """
        Predictor phase
        """

        for hook in self.hooks:
            hook.pre_predict(step=self.S, level_number=0)

        if self.params.predict_type is None:
            pass

        elif self.params.predict_type == 'fine_only':
            # do a fine sweep only
            self.S.levels[0].sweep.update_nodes()

        # elif self.params.predict_type == 'libpfasst_style':
        #
        #     # restrict to coarsest level
        #     for l in range(1, len(self.S.levels)):
        #         self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.first:
        #         self.logger.debug('recv data predict: process %s, stage %s, time %s, source %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
        #                            self.S.status.iter))
        #         self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #
        #     # do the sweep with new values
        #     self.S.levels[-1].sweep.update_nodes()
        #     self.S.levels[-1].sweep.compute_end_point()
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.last:
        #         self.logger.debug('send data predict: process %s, stage %s, time %s, target %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
        #                            self.S.status.iter))
        #         self.S.levels[-1].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm).Wait()
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=True)
        #
        #     # go back to fine level, sweeping
        #     for l in range(len(self.S.levels) - 1, 0, -1):
        #         # prolong values
        #         self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])
        #         # on middle levels: do sweep as usual
        #         if l - 1 > 0:
        #             self.S.levels[l - 1].sweep.update_nodes()
        #
        #     # end with a fine sweep
        #     self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'pfasst_burnin':
            # restrict to coarsest level
            for l in range(1, len(self.S.levels)):
                self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])

            for p in range(self.S.status.slot + 1):
                if not p == 0:
                    self.recv_full(comm=comm, level=len(self.S.levels) - 1)
                    if self.S.status.force_done:
                        return None

                # do the sweep with new values
                self.S.levels[-1].sweep.update_nodes()
                self.S.levels[-1].sweep.compute_end_point()

                self.send_full(
                    comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=(p == self.S.status.slot)
                )
                if self.S.status.force_done:
                    return None

            # interpolate back to finest level
            for l in range(len(self.S.levels) - 1, 0, -1):
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            self.recv_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # end this with a fine sweep
            self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'fmg':
            # TODO: implement FMG predictor
            raise NotImplementedError('FMG predictor is not yet implemented')

        else:
            raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)

        for hook in self.hooks:
            hook.post_predict(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'

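    # The 'pfasst_burnin' predictor builds the usual PFASST staircase: rank p
    # performs p + 1 coarse sweeps and receives from rank p - 1 before all but
    # the first one, so coarse information from t0 has crossed the whole block
    # before the first fine sweep happens.
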
    def it_check(self, comm, num_procs):
        """
        Key routine to check for convergence/termination
        """

        # update values to compute the residual
        self.send_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        self.recv_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        # compute the residual
        self.S.levels[0].sweep.compute_residual(stage='IT_CHECK')

        if self.params.use_iteration_estimator:
            # TODO: replace with convergence controller
            self.check_iteration_estimate(comm=comm)

        if self.S.status.force_done:
            return None

        if self.S.status.iter > 0:
            for hook in self.hooks:
                hook.post_iteration(step=self.S, level_number=0)

        # decide if the step is done, needs to be restarted, and other convergence-related things
        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_iteration_processing(self, self.S, comm=comm)
            C.convergence_control(self, self.S, comm=comm)

        # if not ready, keep iterating
        if not self.S.status.done:
            # increment iteration count here (and only here)
            self.S.status.iter += 1

            for hook in self.hooks:
                hook.pre_iteration(step=self.S, level_number=0)
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.pre_iteration_processing(self, self.S, comm=comm)

            if self.params.use_iteration_estimator:
                # store previous iterate to compute difference later on
                self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

            if len(self.S.levels) > 1:  # MLSDC or PFASST
                self.S.status.stage = 'IT_DOWN'
            else:
                if num_procs == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                    self.S.status.stage = 'IT_FINE'
                else:
                    self.S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-like)
        else:
            if not self.params.use_iteration_estimator:
                # Need to finish all pending isend requests. These will occur for the first active process, since
                # in the last iteration the wait statement will not be called ("send and forget")
                for req in self.req_send:
                    if req is not None:
                        req.Wait()
                if self.req_status is not None:
                    self.req_status.Wait()
                if self.req_diff is not None:
                    self.req_diff.Wait()
            else:
                for req in self.req_send:
                    if req is not None:
                        req.Cancel()
                if self.req_status is not None:
                    self.req_status.Cancel()
                if self.req_diff is not None:
                    self.req_diff.Cancel()

            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)
            self.S.status.stage = 'DONE'

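    # Note on the cleanup above: without the iteration estimator, pending
    # isends are completed with Wait(), since the matching receives are
    # guaranteed to happen. With the estimator, the downstream rank may have
    # been interrupted already and will never post the receive, so the
    # requests are Cancelled instead.
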
    def it_fine(self, comm, num_procs):
        """
        Fine sweeps
        """

        nsweeps = self.S.levels[0].params.nsweeps

        self.S.levels[0].status.sweep = 0

        # do fine sweep
        for k in range(nsweeps):
            self.S.levels[0].status.sweep += 1

            # send values forward
            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # recv values from previous
            self.recv_full(comm=comm, level=0, add_to_stats=(k == nsweeps - 1))
            if self.S.status.force_done:
                return None

            for hook in self.hooks:
                hook.pre_sweep(step=self.S, level_number=0)
            self.S.levels[0].sweep.update_nodes()
            self.S.levels[0].sweep.compute_residual(stage='IT_FINE')
            for hook in self.hooks:
                hook.post_sweep(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'

    def it_down(self, comm, num_procs):
        """
        Go down the hierarchy from finest to coarsest level
        """

        self.S.transfer(source=self.S.levels[0], target=self.S.levels[1])

        # sweep and send on middle levels (not on finest, not on coarsest, though)
        for l in range(1, len(self.S.levels) - 1):
            nsweeps = self.S.levels[l].params.nsweeps

            for _ in range(nsweeps):
                self.send_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                self.recv_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                for hook in self.hooks:
                    hook.pre_sweep(step=self.S, level_number=l)
                self.S.levels[l].sweep.update_nodes()
                self.S.levels[l].sweep.compute_residual(stage='IT_DOWN')
                for hook in self.hooks:
                    hook.post_sweep(step=self.S, level_number=l)

            # transfer further down the hierarchy
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])

        # update stage
        self.S.status.stage = 'IT_COARSE'

    def it_coarse(self, comm, num_procs):
        """
        Coarse sweep
        """

        # receive from previous step (if not first)
        self.recv_full(comm=comm, level=len(self.S.levels) - 1)
        if self.S.status.force_done:
            return None

        # do the sweep
        for hook in self.hooks:
            hook.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        assert self.S.levels[-1].params.nsweeps == 1, (
            'ERROR: this controller can only work with one sweep on the coarse level, got %s'
            % self.S.levels[-1].params.nsweeps
        )
        self.S.levels[-1].sweep.update_nodes()
        self.S.levels[-1].sweep.compute_residual(stage='IT_COARSE')
        for hook in self.hooks:
            hook.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        self.S.levels[-1].sweep.compute_end_point()

        # send to next step
        self.send_full(comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=True)
        if self.S.status.force_done:
            return None

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST
            self.S.status.stage = 'IT_UP'
        else:
            self.S.status.stage = 'IT_CHECK'  # MSSDC

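    # The coarse-level send above is deliberately blocking: the coarse sweep
    # propagates the correction serially from one step to the next, which is
    # the serial, Gauss-Seidel-like part of PFASST that drives convergence
    # across the time ranks.
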
    def it_up(self, comm, num_procs):
        """
        Prolong corrections up to finest level (parallel)
        """

        # receive and sweep on middle levels (except for coarsest level)
        for l in range(len(self.S.levels) - 1, 0, -1):
            # prolong values
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            # on middle levels: do sweep as usual
            if l - 1 > 0:
                nsweeps = self.S.levels[l - 1].params.nsweeps

                for k in range(nsweeps):
                    self.send_full(comm, level=l - 1)
                    if self.S.status.force_done:
                        return None

                    self.recv_full(comm=comm, level=l - 1, add_to_stats=(k == nsweeps - 1))
                    if self.S.status.force_done:
                        return None

                    for hook in self.hooks:
                        hook.pre_sweep(step=self.S, level_number=l - 1)
                    self.S.levels[l - 1].sweep.update_nodes()
                    self.S.levels[l - 1].sweep.compute_residual(stage='IT_UP')
                    for hook in self.hooks:
                        hook.post_sweep(step=self.S, level_number=l - 1)

        # update stage
        self.S.status.stage = 'IT_FINE'

    def default(self, comm, num_procs):
        """
        Default routine to catch wrong status
        """
        # signature matches the switcher fallback call: switcher.get(stage, self.default)(comm, num_procs)
        raise ControllerError('Weird stage, got %s' % self.S.status.stage)