# Coverage report for pySDC/implementations/controller_classes/controller_MPI.py: 66% of 384 statements covered
# (coverage.py v7.6.1, created at 2024-09-09 14:59 +0000)
import numpy as np
from mpi4py import MPI

from pySDC.core.controller import Controller
from pySDC.core.errors import ControllerError
from pySDC.core.step import Step
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting

class controller_MPI(Controller):
    """
    PFASST controller, running a parallel version of PFASST in blocks (MG-style)
    """
    def __init__(self, controller_params, description, comm):
        """
        Initialization routine for PFASST controller

        Args:
            controller_params: parameter set for the controller and the step class
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
            comm: MPI communicator
        """
        # call parent's initialization routine
        super().__init__(controller_params, description, useMPI=True)

        # create single step per processor
        self.S = Step(description)

        # pass communicator for future use
        self.comm = comm

        num_procs = self.comm.Get_size()
        rank = self.comm.Get_rank()

        # insert data on time communicator to the steps (helpful here and there)
        self.S.status.time_size = num_procs

        self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=True)]
        for convergence_controller in self.base_convergence_controllers:
            self.add_convergence_controller(convergence_controller, description)

        if self.params.dump_setup and rank == 0:
            self.dump_setup(step=self.S, controller_params=controller_params, description=description)

        num_levels = len(self.S.levels)

        # add request handler for status send
        self.req_status = None
        # add request handle container for isend
        self.req_send = [None] * num_levels
        self.req_ibcast = None
        self.req_diff = None

        if num_procs > 1 and num_levels > 1:
            for L in self.S.levels:
                if not L.sweep.coll.right_is_node or L.sweep.params.do_coll_update:
                    raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

        if num_levels == 1 and self.params.predict_type is not None:
            self.logger.warning(
                'you have specified a predictor type but only a single level.. predictor will be ignored'
            )

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.setup_status_variables(self, comm=comm)
    def run(self, u0, t0, Tend):
        """
        Main driver for running the parallel version of SDC, MSSDC, MLSDC and PFASST

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the finest level
            stats object containing statistics for each step, each level and each iteration
        """
        # reset stats to prevent double entries from old runs
        for hook in self.hooks:
            hook.reset_stats()

        # setup time initially
        all_dt = self.comm.allgather(self.S.dt)
        time = t0 + sum(all_dt[: self.comm.rank])
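        # e.g. with 4 ranks and dt = 0.1 on each of them, rank 2 starts its step at t0 + 0.2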
        active = time < Tend - 10 * np.finfo(float).eps
        comm_active = self.comm.Split(active)
        self.S.status.slot = comm_active.rank

        if self.comm.rank == 0 and not active:
            raise ControllerError('Nothing to do, check t0, dt and Tend!')

        # initialize block of steps with u0
        self.restart_block(comm_active.size, time, u0, comm=comm_active)
        uend = u0

        # call post-setup hook
        for hook in self.hooks:
            hook.post_setup(step=None, level_number=None)

        # call pre-run hook
        for hook in self.hooks:
            hook.pre_run(step=self.S, level_number=0)

        comm_active.Barrier()

        # while any process still active...
        while active:
            while not self.S.status.done:
                self.pfasst(comm_active, comm_active.size)

            # determine where to restart
            restarts = comm_active.allgather(self.S.status.restart)

            # communicate time and solution to be used as next initial conditions
            if True in restarts:
                restart_at = np.where(restarts)[0][0]
                uend = self.S.levels[0].u[0].bcast(root=restart_at, comm=comm_active)
                tend = comm_active.bcast(self.S.time, root=restart_at)
                self.logger.info(f'Starting next block with initial conditions from step {restart_at}')
            else:
                uend = self.S.levels[0].uend.bcast(root=comm_active.size - 1, comm=comm_active)
                tend = comm_active.bcast(self.S.time + self.S.dt, root=comm_active.size - 1)

            # do convergence controller stuff
            if not self.S.status.restart:
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.post_step_processing(self, self.S, comm=comm_active)

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.prepare_next_block(self, self.S, self.S.status.time_size, tend, Tend, comm=comm_active)

            # set new time
            all_dt = comm_active.allgather(self.S.dt)
            time = tend + sum(all_dt[: self.S.status.slot])

            active = time < Tend - 10 * np.finfo(float).eps

            # check if we need to split the communicator
            if tend + sum(all_dt[: comm_active.size - 1]) >= Tend - 10 * np.finfo(float).eps:
                comm_active_new = comm_active.Split(active)
                comm_active.Free()
                comm_active = comm_active_new

            self.S.status.slot = comm_active.rank
            # initialize next block of steps with the new initial conditions
            if active:
                self.restart_block(comm_active.size, time, uend, comm=comm_active)
        # call post-run hook
        for hook in self.hooks:
            hook.post_run(step=self.S, level_number=0)

        comm_active.Free()

        return uend, self.return_stats()
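    # Usage sketch (illustration only, not part of this class): assuming `controller_params` and
    # `description` dictionaries prepared as in the pySDC tutorials, and a problem class that
    # provides `u_exact`, a run with this controller looks roughly like
    #
    #     comm = MPI.COMM_WORLD
    #     controller = controller_MPI(controller_params, description, comm)
    #     uinit = controller.S.levels[0].prob.u_exact(0.0)
    #     uend, stats = controller.run(u0=uinit, t0=0.0, Tend=1.0)
    #
    # launched under MPI, e.g. `mpiexec -n 4 python my_script.py`.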
    def restart_block(self, size, time, u0, comm):
        """
        Helper routine to reset/restart the block of (active) steps in place

        Args:
            size: number of active time steps
            time: current time
            u0: initial value to distribute across the steps
            comm: the communicator
        """
        # store links to previous and next step
        self.S.prev = (self.S.status.slot - 1) % size
        self.S.next = (self.S.status.slot + 1) % size

        # reset step
        self.S.reset_step()

        # determine whether I am the first and/or last in line
        self.S.status.first = self.S.prev == size - 1
        self.S.status.last = self.S.next == 0
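        # e.g. with size = 4: slot 0 wraps around to prev = 3, hence is first; slot 3 wraps to next = 0, hence is last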
        # initialize step with u0
        self.S.init_step(u0)

        # reset some values
        self.S.status.done = False
        self.S.status.iter = 0
        self.S.status.stage = 'SPREAD'
        for l in self.S.levels:
            l.tag = None
        self.req_status = None
        self.req_ibcast = None
        self.req_diff = None
        self.req_send = [None] * len(self.S.levels)
        self.S.status.prev_done = False
        self.S.status.force_done = False

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_status_variables(self, comm=comm)

        self.S.status.time_size = size

        for lvl in self.S.levels:
            lvl.status.time = time
            lvl.status.sweep = 1
    def recv(self, target, source, tag=None, comm=None):
        """
        Receive function

        Args:
            target: level which will receive the values
            source: level which initiated the send
            tag: identifier to check if this message is really for me
            comm: communicator
        """
        req = target.u[0].irecv(source=source, tag=tag, comm=comm)
        self.wait_with_interrupt(request=req)
        if self.S.status.force_done:
            return None

        # re-evaluate f on left interval boundary
        target.f[0] = target.prob.eval_f(target.u[0], target.time)
    def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False):
        """
        Function to perform the send, including bookkeeping and logging

        Args:
            comm: the communicator
            blocking: flag to indicate that we need blocking communication
            level: the level number
            add_to_stats: a flag to end recording data in the hooks (defaults to False)
        """
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)

        if not blocking:
            self.wait_with_interrupt(request=self.req_send[level])
            if self.S.status.force_done:
                return None

        self.S.levels[level].sweep.compute_end_point()

        if not self.S.status.last:
            self.logger.debug(
                'isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.next,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
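            # the tag encodes both level and iteration (tag = 100 * level + iter),
            # so messages from different levels/iterations cannot be confused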
            self.req_send[level] = self.S.levels[level].uend.isend(
                dest=self.S.next, tag=level * 100 + self.S.status.iter, comm=comm
            )
            if blocking:
                self.wait_with_interrupt(request=self.req_send[level])
                if self.S.status.force_done:
                    return None

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
    def recv_full(self, comm, level=None, add_to_stats=False):
        """
        Function to perform the recv, including bookkeeping and logging

        Args:
            comm: the communicator
            level: the level number
            add_to_stats: a flag to end recording data in the hooks (defaults to False)
        """
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=level)

        if not self.S.status.first and not self.S.status.prev_done:
            self.logger.debug(
                'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
                % (
                    self.S.status.slot,
                    self.S.status.stage,
                    self.S.time,
                    self.S.prev,
                    level * 100 + self.S.status.iter,
                    self.S.status.iter,
                )
            )
            self.recv(target=self.S.levels[level], source=self.S.prev, tag=level * 100 + self.S.status.iter, comm=comm)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
    def wait_with_interrupt(self, request):
        """
        Wrapper for waiting for the completion of a non-blocking communication; can be interrupted

        Args:
            request: request to wait for
        """
        if request is not None and self.req_ibcast is not None:
            while not request.Test():
                if self.req_ibcast.Test():
                    self.logger.debug(f'{self.S.status.slot} has been cancelled during {self.S.status.stage}..')
                    self.S.status.stage = f'CANCELLED_{self.S.status.stage}'
                    self.S.status.force_done = True
                    return None
        if request is not None:
            request.Wait()
    def check_iteration_estimate(self, comm):
        """
        Routine to compute and check error/iteration estimation

        Args:
            comm: time-communicator
        """
        # Compute diff between old and new values
        diff_new = 0.0
        L = self.S.levels[0]

        for m in range(1, L.sweep.coll.num_nodes + 1):
            diff_new = max(diff_new, abs(L.uold[m] - L.u[m]))

        # Send forward diff
        for hook in self.hooks:
            hook.pre_comm(step=self.S, level_number=0)

        self.wait_with_interrupt(request=self.req_diff)
        if self.S.status.force_done:
            return None

        if not self.S.status.first:
            prev_diff = np.empty(1, dtype=float)
            req = comm.Irecv((prev_diff, MPI.DOUBLE), source=self.S.prev, tag=999)
            self.wait_with_interrupt(request=req)
            if self.S.status.force_done:
                return None
            self.logger.debug(
                'recv diff: status %s, process %s, time %s, source %s, tag %s, iter %s'
                % (prev_diff, self.S.status.slot, self.S.time, self.S.prev, 999, self.S.status.iter)
            )
            diff_new = max(prev_diff[0], diff_new)

        if not self.S.status.last:
            self.logger.debug(
                'isend diff: status %s, process %s, time %s, target %s, tag %s, iter %s'
                % (diff_new, self.S.status.slot, self.S.time, self.S.next, 999, self.S.status.iter)
            )
            tmp = np.array(diff_new, dtype=float)
            self.req_diff = comm.Issend((tmp, MPI.DOUBLE), dest=self.S.next, tag=999)

        for hook in self.hooks:
            hook.post_comm(step=self.S, level_number=0)

        # Store values from first iteration
        if self.S.status.iter == 1:
            self.S.status.diff_old_loc = diff_new
            self.S.status.diff_first_loc = diff_new
        # Compute iteration estimate
        elif self.S.status.iter > 1:
            Ltilde_loc = min(diff_new / self.S.status.diff_old_loc, 0.9)
            self.S.status.diff_old_loc = diff_new
            alpha = 1 / (1 - Ltilde_loc) * self.S.status.diff_first_loc
            Kest_loc = np.log(self.S.params.errtol / alpha) / np.log(Ltilde_loc) * 1.05  # Safety factor!
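            # worked sketch with illustrative numbers: diff_first = 1e-2 and diff_new = 1e-3
            # give Ltilde = min(1e-3 / 1e-2, 0.9) = 0.1 and alpha = 1e-2 / (1 - 0.1) ~ 1.11e-2;
            # with errtol = 1e-8 this yields Kest = log(1e-8 / 1.11e-2) / log(0.1) * 1.05 ~ 6.3,
            # i.e. about 7 iterations are predicted until the tolerance is met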
            self.logger.debug(
                f'LOCAL: {L.time:8.4f}, {self.S.status.iter}: {int(np.ceil(Kest_loc))}, '
                f'{Ltilde_loc:8.6e}, {Kest_loc:8.6e}, '
                f'{Ltilde_loc ** self.S.status.iter * alpha:8.6e}'
            )
            Kest_glob = Kest_loc
            # If condition is met, send interrupt
            if np.ceil(Kest_glob) <= self.S.status.iter:
                if self.S.status.last:
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting..')
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
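                    # this completes the Ibcast posted by the other (non-last) ranks in pfasst()
                    # and signals them to stop iterating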
                    # use a C int buffer so the payload matches the declared MPI.INT type
                    comm.Ibcast((np.array([1], dtype=np.intc), MPI.INT), root=self.S.status.slot).Wait()
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
                    self.logger.debug(f'{self.S.status.slot} is done, broadcasting done')
                    self.S.status.done = True
                else:
                    for hook in self.hooks:
                        hook.pre_comm(step=self.S, level_number=0)
                    for hook in self.hooks:
                        hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
    def pfasst(self, comm, num_procs):
        """
        Main function including the stages of SDC, MLSDC and PFASST (the "controller")

        For the workflow of this controller, check out one of our PFASST talks or the pySDC paper

        Args:
            comm: communicator
            num_procs (int): number of parallel processes
        """
        stage = self.S.status.stage

        self.logger.debug(stage + ' - process ' + str(self.S.status.slot))

        # Wait for interrupt, if iteration estimator is used
        if self.params.use_iteration_estimator and stage == 'SPREAD' and not self.S.status.last:
            done = np.empty(1, dtype=np.intc)  # C int buffer to match the declared MPI.INT type
            self.req_ibcast = comm.Ibcast((done, MPI.INT), root=comm.Get_size() - 1)

        # If the interrupt has arrived, clean up and finish
        if self.params.use_iteration_estimator and not self.S.status.last and self.req_ibcast.Test():
            self.logger.debug(f'{self.S.status.slot} is done..')
            self.S.status.done = True

            if not stage == 'IT_CHECK':
                self.logger.debug(f'Rewinding {self.S.status.slot} after {stage}..')
                self.S.levels[0].u[1:] = self.S.levels[0].uold[1:]

                for hook in self.hooks:
                    hook.post_iteration(step=self.S, level_number=0)

            for req in self.req_send:
                if req is not None and req != MPI.REQUEST_NULL:
                    req.Cancel()
            if self.req_status is not None and self.req_status != MPI.REQUEST_NULL:
                self.req_status.Cancel()
            if self.req_diff is not None and self.req_diff != MPI.REQUEST_NULL:
                self.req_diff.Cancel()

            self.S.status.stage = 'DONE'
            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)

        else:
            # Start cycling, if not interrupted
            switcher = {
                'SPREAD': self.spread,
                'PREDICT': self.predict,
                'IT_CHECK': self.it_check,
                'IT_FINE': self.it_fine,
                'IT_DOWN': self.it_down,
                'IT_COARSE': self.it_coarse,
                'IT_UP': self.it_up,
            }

            switcher.get(stage, self.default)(comm, num_procs)
    def spread(self, comm, num_procs):
        """
        Spreading phase
        """
        # first stage: spread values
        for hook in self.hooks:
            hook.pre_step(step=self.S, level_number=0)

        # call predictor from sweeper
        self.S.levels[0].sweep.predict()

        if self.params.use_iteration_estimator:
            # store previous iterate to compute difference later on
            self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST with predict
            self.S.status.stage = 'PREDICT'
        else:
            self.S.status.stage = 'IT_CHECK'

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_spread_processing(self, self.S, comm=comm)
    def predict(self, comm, num_procs):
        """
        Predictor phase
        """
        for hook in self.hooks:
            hook.pre_predict(step=self.S, level_number=0)

        if self.params.predict_type is None:
            pass

        elif self.params.predict_type == 'fine_only':
            # do a fine sweep only
            self.S.levels[0].sweep.update_nodes()

        # elif self.params.predict_type == 'libpfasst_style':
        #
        #     # restrict to coarsest level
        #     for l in range(1, len(self.S.levels)):
        #         self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.first:
        #         self.logger.debug('recv data predict: process %s, stage %s, time %s, source %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
        #                            self.S.status.iter))
        #         self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #
        #     # do the sweep with new values
        #     self.S.levels[-1].sweep.update_nodes()
        #     self.S.levels[-1].sweep.compute_end_point()
        #
        #     self.hooks.pre_comm(step=self.S, level_number=len(self.S.levels) - 1)
        #     if not self.S.status.last:
        #         self.logger.debug('send data predict: process %s, stage %s, time %s, target %s, tag %s' %
        #                           (self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
        #                            self.S.status.iter))
        #         self.S.levels[-1].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm).Wait()
        #     self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=True)
        #
        #     # go back to fine level, sweeping
        #     for l in range(len(self.S.levels) - 1, 0, -1):
        #         # prolong values
        #         self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])
        #         # on middle levels: do sweep as usual
        #         if l - 1 > 0:
        #             self.S.levels[l - 1].sweep.update_nodes()
        #
        #     # end with a fine sweep
        #     self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'pfasst_burnin':
            # restrict to coarsest level
            for l in range(1, len(self.S.levels)):
                self.S.transfer(source=self.S.levels[l - 1], target=self.S.levels[l])

            for p in range(self.S.status.slot + 1):
                if not p == 0:
                    self.recv_full(comm=comm, level=len(self.S.levels) - 1)
                    if self.S.status.force_done:
                        return None

                # do the sweep with new values
                self.S.levels[-1].sweep.update_nodes()
                self.S.levels[-1].sweep.compute_end_point()

                self.send_full(
                    comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=(p == self.S.status.slot)
                )
                if self.S.status.force_done:
                    return None

            # interpolate back to finest level
            for l in range(len(self.S.levels) - 1, 0, -1):
                self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            self.recv_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # end this with a fine sweep
            self.S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'fmg':
            # TODO: implement FMG predictor
            raise NotImplementedError('FMG predictor is not yet implemented')

        else:
            raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)

        for hook in self.hooks:
            hook.post_predict(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'
    def it_check(self, comm, num_procs):
        """
        Key routine to check for convergence/termination
        """
        # Update values to compute the residual
        self.send_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        self.recv_full(comm=comm, level=0)
        if self.S.status.force_done:
            return None

        # compute the residual
        self.S.levels[0].sweep.compute_residual(stage='IT_CHECK')

        if self.params.use_iteration_estimator:
            # TODO: replace with convergence controller
            self.check_iteration_estimate(comm=comm)

        if self.S.status.force_done:
            return None

        if self.S.status.iter > 0:
            for hook in self.hooks:
                hook.post_iteration(step=self.S, level_number=0)

        # decide if the step is done, needs to be restarted, and other convergence-related things
        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.post_iteration_processing(self, self.S, comm=comm)
            C.convergence_control(self, self.S, comm=comm)

        # if not done yet, keep iterating
        if not self.S.status.done:
            # increment iteration count here (and only here)
            self.S.status.iter += 1

            for hook in self.hooks:
                hook.pre_iteration(step=self.S, level_number=0)
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.pre_iteration_processing(self, self.S, comm=comm)

            if self.params.use_iteration_estimator:
                # store previous iterate to compute difference later on
                self.S.levels[0].uold[1:] = self.S.levels[0].u[1:]

            if len(self.S.levels) > 1:  # MLSDC or PFASST
                self.S.status.stage = 'IT_DOWN'
            else:
                if num_procs == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                    self.S.status.stage = 'IT_FINE'
                else:
                    self.S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-like)

        else:
            if not self.params.use_iteration_estimator:
                # We need to finish all pending isend requests. These occur for the first active process, since
                # in the last iteration the wait statement is not called ("send and forget").
                for req in self.req_send:
                    if req is not None:
                        req.Wait()
                if self.req_status is not None:
                    self.req_status.Wait()
                if self.req_diff is not None:
                    self.req_diff.Wait()
            else:
                for req in self.req_send:
                    if req is not None:
                        req.Cancel()
                if self.req_status is not None:
                    self.req_status.Cancel()
                if self.req_diff is not None:
                    self.req_diff.Cancel()

            for hook in self.hooks:
                hook.post_step(step=self.S, level_number=0)
            self.S.status.stage = 'DONE'
    def it_fine(self, comm, num_procs):
        """
        Fine sweeps
        """
        nsweeps = self.S.levels[0].params.nsweeps

        self.S.levels[0].status.sweep = 0

        # do fine sweep
        for k in range(nsweeps):
            self.S.levels[0].status.sweep += 1

            # send values forward
            self.send_full(comm=comm, level=0)
            if self.S.status.force_done:
                return None

            # recv values from previous
            self.recv_full(comm=comm, level=0, add_to_stats=(k == nsweeps - 1))
            if self.S.status.force_done:
                return None

            for hook in self.hooks:
                hook.pre_sweep(step=self.S, level_number=0)
            self.S.levels[0].sweep.update_nodes()
            self.S.levels[0].sweep.compute_residual(stage='IT_FINE')
            for hook in self.hooks:
                hook.post_sweep(step=self.S, level_number=0)

        # update stage
        self.S.status.stage = 'IT_CHECK'
    def it_down(self, comm, num_procs):
        """
        Go down the hierarchy from finest to coarsest level
        """
        self.S.transfer(source=self.S.levels[0], target=self.S.levels[1])

        # sweep and send on middle levels (not on finest, not on coarsest, though)
        for l in range(1, len(self.S.levels) - 1):
            nsweeps = self.S.levels[l].params.nsweeps

            for _ in range(nsweeps):
                self.send_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                self.recv_full(comm=comm, level=l)
                if self.S.status.force_done:
                    return None

                for hook in self.hooks:
                    hook.pre_sweep(step=self.S, level_number=l)
                self.S.levels[l].sweep.update_nodes()
                self.S.levels[l].sweep.compute_residual(stage='IT_DOWN')
                for hook in self.hooks:
                    hook.post_sweep(step=self.S, level_number=l)

            # transfer further down the hierarchy
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])

        # update stage
        self.S.status.stage = 'IT_COARSE'
    def it_coarse(self, comm, num_procs):
        """
        Coarse sweep
        """
        # receive from previous step (if not first)
        self.recv_full(comm=comm, level=len(self.S.levels) - 1)
        if self.S.status.force_done:
            return None

        # do the sweep
        for hook in self.hooks:
            hook.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        assert self.S.levels[-1].params.nsweeps == 1, (
            'ERROR: this controller can only work with one sweep on the coarse level, got %s'
            % self.S.levels[-1].params.nsweeps
        )
        self.S.levels[-1].sweep.update_nodes()
        self.S.levels[-1].sweep.compute_residual(stage='IT_COARSE')
        for hook in self.hooks:
            hook.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
        self.S.levels[-1].sweep.compute_end_point()

        # send to next step
        self.send_full(comm=comm, blocking=True, level=len(self.S.levels) - 1, add_to_stats=True)
        if self.S.status.force_done:
            return None

        # update stage
        if len(self.S.levels) > 1:  # MLSDC or PFASST
            self.S.status.stage = 'IT_UP'
        else:
            self.S.status.stage = 'IT_CHECK'  # MSSDC
    def it_up(self, comm, num_procs):
        """
        Prolong corrections up to finest level (parallel)
        """
        # receive and sweep on middle levels (except for coarsest level)
        for l in range(len(self.S.levels) - 1, 0, -1):
            # prolong values
            self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])

            # on middle levels: do sweep as usual
            if l - 1 > 0:
                nsweeps = self.S.levels[l - 1].params.nsweeps

                for k in range(nsweeps):
                    self.send_full(comm=comm, level=l - 1)
                    if self.S.status.force_done:
                        return None

                    self.recv_full(comm=comm, level=l - 1, add_to_stats=(k == nsweeps - 1))
                    if self.S.status.force_done:
                        return None

                    for hook in self.hooks:
                        hook.pre_sweep(step=self.S, level_number=l - 1)
                    self.S.levels[l - 1].sweep.update_nodes()
                    self.S.levels[l - 1].sweep.compute_residual(stage='IT_UP')
                    for hook in self.hooks:
                        hook.post_sweep(step=self.S, level_number=l - 1)

        # update stage
        self.S.status.stage = 'IT_FINE'
    def default(self, comm, num_procs):
        """
        Default routine to catch wrong status
        """
        raise ControllerError('Weird stage, got %s' % self.S.status.stage)