Coverage for pySDC/implementations/controller_classes/controller_nonMPI.py: 99%
298 statements

import itertools
import copy as cp
import numpy as np
import dill

from pySDC.core.controller import Controller
from pySDC.core import step as stepclass
from pySDC.core.errors import ControllerError, CommunicationError
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting


class controller_nonMPI(Controller):
    """
    PFASST controller, running serialized version of PFASST in blocks (MG-style)
    """

    def __init__(self, num_procs, controller_params, description):
        """
        Initialization routine for PFASST controller

        Args:
            num_procs: number of parallel time steps (still serial, though), can be 1
            controller_params: parameter set for the controller and the steps
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
        """

        if 'predict' in controller_params:
            raise ControllerError('predict flag is ignored, use predict_type instead')

        # call parent's initialization routine
        super().__init__(controller_params, description, useMPI=False)

        self.MS = [stepclass.Step(description)]

        # try to initialize via dill.copy (much faster for many time-steps)
        try:
            for _ in range(num_procs - 1):
                self.MS.append(dill.copy(self.MS[0]))
        # if this fails (e.g. due to un-picklable data in the steps), initialize separately
        except (dill.PicklingError, TypeError, ValueError) as error:
            self.logger.warning(f'Need to initialize steps separately due to pickling error: {error}')
            for _ in range(num_procs - 1):
                self.MS.append(stepclass.Step(description))

        self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=False)]
        for convergence_controller in self.base_convergence_controllers:
            self.add_convergence_controller(convergence_controller, description)

        if self.params.dump_setup:
            self.dump_setup(step=self.MS[0], controller_params=controller_params, description=description)

        if num_procs > 1 and len(self.MS[0].levels) > 1:
            for S in self.MS:
                for L in S.levels:
                    if not L.sweep.coll.right_is_node:
                        raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

        if all(len(S.levels) == len(self.MS[0].levels) for S in self.MS):
            self.nlevels = len(self.MS[0].levels)
        else:
            raise ControllerError('all steps need to have the same number of levels')

        if self.nlevels == 0:
            raise ControllerError('need at least one level')

        self.nsweeps = []
        for nl in range(self.nlevels):
            if all(S.levels[nl].params.nsweeps == self.MS[0].levels[nl].params.nsweeps for S in self.MS):
                self.nsweeps.append(self.MS[0].levels[nl].params.nsweeps)

        if self.nlevels > 1 and self.nsweeps[-1] > 1:
            raise ControllerError('this controller cannot do multiple sweeps on coarsest level')

        if self.nlevels == 1 and self.params.predict_type is not None:
            self.logger.warning(
                'you have specified a predictor type but only a single level... predictor will be ignored'
            )

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_buffers_nonMPI(self)
            C.setup_status_variables(self, MS=self.MS)

    def run(self, u0, t0, Tend):
        """
        Main driver for running the serial version of SDC, MSSDC, MLSDC and PFASST (virtual parallelism)

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the finest level
            stats object containing statistics for each step, each level and each iteration
        """

        # some initializations and reset of statistics
        uend = None
        num_procs = len(self.MS)
        for hook in self.hooks:
            hook.reset_stats()

        # initial ordering of the steps: 0,1,...,Np-1
        slots = list(range(num_procs))

        # initialize time variables of each step
        time = [t0 + sum(self.MS[j].dt for j in range(p)) for p in slots]
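        # e.g. with t0 = 0.0, four steps and a uniform dt = 0.1 this gives
        # time = [0.0, 0.1, 0.2, 0.3] (hypothetical values for illustration)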

        # determine which steps are still active (time < Tend)
        active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]
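        # the 10 * eps safety margin keeps steps whose start time already
        # (nearly) equals Tend from being counted as active due to round-off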

        if not any(active):
            raise ControllerError('Nothing to do, check t0, dt and Tend.')

        # compress slots according to active steps, i.e. remove all steps which have times above Tend
        active_slots = list(itertools.compress(slots, active))
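        # e.g. slots = [0, 1, 2, 3] and active = [True, True, False, False]
        # yield active_slots = [0, 1]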

        # initialize block of steps with u0
        self.restart_block(active_slots, time, u0)

        for hook in self.hooks:
            hook.post_setup(step=None, level_number=None)

        # call pre-run hook
        for S in self.MS:
            for hook in self.hooks:
                hook.pre_run(step=S, level_number=0)

        # main loop: as long as at least one step is still active (time < Tend), do something
        while any(active):
            MS_active = [self.MS[p] for p in active_slots]
            done = False
            while not done:
                done = self.pfasst(MS_active)

            restarts = [S.status.restart for S in MS_active]
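            # index of the first step flagged for restart; len(MS_active) means no restart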
            restart_at = np.where(restarts)[0][0] if True in restarts else len(MS_active)
            if True in restarts:  # restart part of the block
                # initial condition to next block is initial condition of step that needs restarting
                uend = self.MS[restart_at].levels[0].u[0]
                time[active_slots[0]] = time[restart_at]
                self.logger.info(f'Starting next block with initial conditions from step {restart_at}')

            else:  # move on to next block
                # initial condition for next block is last solution of current block
                uend = self.MS[active_slots[-1]].levels[0].uend
                time[active_slots[0]] = time[active_slots[-1]] + self.MS[active_slots[-1]].dt

            for S in MS_active[:restart_at]:
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.post_step_processing(self, S, MS=MS_active)

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                for S in self.MS:
                    C.prepare_next_block(self, S, len(active_slots), time, Tend, MS=MS_active)

            # setup the times of the steps for the next block
            for i in range(1, len(active_slots)):
                time[active_slots[i]] = time[active_slots[i] - 1] + self.MS[active_slots[i] - 1].dt

            # determine new set of active steps and compress slots accordingly
            active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]
            active_slots = list(itertools.compress(slots, active))

            # restart active steps (reset all values and pass uend to u0)
            self.restart_block(active_slots, time, uend)

        # call post-run hook
        for S in self.MS:
            for hook in self.hooks:
                hook.post_run(step=S, level_number=0)

        for S in self.MS:
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_run_processing(self, S, MS=MS_active)

        return uend, self.return_stats()

    def restart_block(self, active_slots, time, u0):
        """
        Helper routine to reset/restart block of (active) steps

        Args:
            active_slots: list of active steps
            time: list of new times
            u0: initial value to distribute across the steps
        """

        # loop over active slots (not directly, since we need the previous entry as well)
        for j in range(len(active_slots)):
            # get slot number
            p = active_slots[j]

            # store current slot number for diagnostics
            self.MS[p].status.slot = p
            # store link to previous step
            self.MS[p].prev = self.MS[active_slots[j - 1]]
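            # (for j == 0 the index j - 1 == -1 wraps around, so the first
            # active step is linked to the last one)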
            # resets step
            self.MS[p].reset_step()
            # determine whether I am the first and/or last in line
            self.MS[p].status.first = active_slots.index(p) == 0
            self.MS[p].status.last = active_slots.index(p) == len(active_slots) - 1
            # initialize step with u0
            self.MS[p].init_step(u0)
            # reset some values
            self.MS[p].status.done = False
            self.MS[p].status.prev_done = False
            self.MS[p].status.iter = 0
            self.MS[p].status.stage = 'SPREAD'
            self.MS[p].status.force_done = False
            self.MS[p].status.time_size = len(active_slots)

            for l in self.MS[p].levels:
                l.tag = None
                l.status.sweep = 1

        for p in active_slots:
            for lvl in self.MS[p].levels:
                lvl.status.time = time[p]

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_status_variables(self, active_slots=active_slots)

    def send_full(self, S, level=None, add_to_stats=False):
        """
        Function to perform the send, including bookkeeping and logging

        Args:
            S: the current step
            level: the level number
            add_to_stats: a flag to end recording data in the hooks (defaults to False)
        """

        def send(source, tag):
            """
            Send function

            Args:
                source: level which has the new values
                tag: identifier for this message
            """
            # sending here means computing uend ("one-sided communication")
            source.sweep.compute_end_point()
            source.tag = cp.deepcopy(tag)
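            # tags are (level, iteration, slot) tuples; recv_full compares
            # them against the expected tag to catch mismatched messages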

        for hook in self.hooks:
            hook.pre_comm(step=S, level_number=level)
        if not S.status.last:
            self.logger.debug(
                'Process %2i provides data on level %2i with tag %s' % (S.status.slot, level, S.status.iter)
            )
            send(S.levels[level], tag=(level, S.status.iter, S.status.slot))

        for hook in self.hooks:
            hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)

    def recv_full(self, S, level=None, add_to_stats=False):
        """
        Function to perform the recv, including bookkeeping and logging

        Args:
            S: the current step
            level: the level number
            add_to_stats: a flag to end recording data in the hooks (defaults to False)
        """

        def recv(target, source, tag=None):
            """
            Receive function

            Args:
                target: level which will receive the values
                source: level which initiated the send
                tag: identifier to check if this message is really for me
            """

            if tag is not None and source.tag != tag:
                raise CommunicationError('source and target tag are not the same, got %s and %s' % (source.tag, tag))
            # simply do a deepcopy of the values uend to become the new u0 at the target
            target.u[0] = target.prob.dtype_u(source.uend)
            # re-evaluate f on left interval boundary
            target.f[0] = target.prob.eval_f(target.u[0], target.time)

        for hook in self.hooks:
            hook.pre_comm(step=S, level_number=level)
        if not S.status.prev_done and not S.status.first:
            self.logger.debug(
                'Process %2i receives from %2i on level %2i with tag %s'
                % (S.status.slot, S.prev.status.slot, level, S.status.iter)
            )
            recv(S.levels[level], S.prev.levels[level], tag=(level, S.status.iter, S.prev.status.slot))
        for hook in self.hooks:
            hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)

    def pfasst(self, local_MS_active):
        """
        Main function including the stages of SDC, MLSDC and PFASST (the "controller")

        For the workflow of this controller, check out one of our PFASST talks or the pySDC paper

        This method changes self.MS directly by accessing active steps through local_MS_active. Nothing is returned.

        Args:
            local_MS_active (list): all active steps
        """

        # if all stages are the same (or DONE), continue, otherwise abort
        stages = [S.status.stage for S in local_MS_active if S.status.stage != 'DONE']
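        # idiom: a list equals its own one-element shift iff all entries are identical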
        if stages[1:] == stages[:-1]:
            stage = stages[0]
        else:
            raise ControllerError('not all stages are equal')

        self.logger.debug(stage)

        MS_running = [S for S in local_MS_active if S.status.stage != 'DONE']

        switcher = {
            'SPREAD': self.spread,
            'PREDICT': self.predict,
            'IT_CHECK': self.it_check,
            'IT_FINE': self.it_fine,
            'IT_DOWN': self.it_down,
            'IT_COARSE': self.it_coarse,
            'IT_UP': self.it_up,
        }
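
        # dispatch to the stage routine; unknown stages fall through to
        # self.default, which raises a ControllerError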
        switcher.get(stage, self.default)(MS_running)

        return all(S.status.done for S in local_MS_active)

    def spread(self, local_MS_running):
        """
        Spreading phase

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            # first stage: spread values
            for hook in self.hooks:
                hook.pre_step(step=S, level_number=0)

            # call predictor from sweeper
            S.levels[0].sweep.predict()

            # update stage
            if len(S.levels) > 1:  # MLSDC or PFASST with predict
                S.status.stage = 'PREDICT'
            else:
                S.status.stage = 'IT_CHECK'

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_spread_processing(self, S, MS=local_MS_running)

    def predict(self, local_MS_running):
        """
        Predictor phase

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            for hook in self.hooks:
                hook.pre_predict(step=S, level_number=0)

        if self.params.predict_type is None:
            pass

        elif self.params.predict_type == 'fine_only':
            # do a fine sweep only
            for S in local_MS_running:
                S.levels[0].sweep.update_nodes()

        # elif self.params.predict_type == 'libpfasst_style':
        #
        #     # loop over all steps
        #     for S in local_MS_running:
        #
        #         # restrict to coarsest level
        #         for l in range(1, len(S.levels)):
        #             S.transfer(source=S.levels[l - 1], target=S.levels[l])
        #
        #     # run in serial on coarse level
        #     for S in local_MS_running:
        #
        #         self.hooks.pre_comm(step=S, level_number=len(S.levels) - 1)
        #         # receive from previous step (if not first)
        #         if not S.status.first:
        #             self.logger.debug('Process %2i receives from %2i on level %2i with tag %s -- PREDICT' %
        #                               (S.status.slot, S.prev.status.slot, len(S.levels) - 1, 0))
        #             self.recv(S.levels[-1], S.prev.levels[-1], tag=(len(S.levels), 0, S.prev.status.slot))
        #         self.hooks.post_comm(step=S, level_number=len(S.levels) - 1)
        #
        #         # do the coarse sweep
        #         S.levels[-1].sweep.update_nodes()
        #
        #         self.hooks.pre_comm(step=S, level_number=len(S.levels) - 1)
        #         # send to succ step
        #         if not S.status.last:
        #             self.logger.debug('Process %2i provides data on level %2i with tag %s -- PREDICT'
        #                               % (S.status.slot, len(S.levels) - 1, 0))
        #             self.send(S.levels[-1], tag=(len(S.levels), 0, S.status.slot))
        #         self.hooks.post_comm(step=S, level_number=len(S.levels) - 1, add_to_stats=True)
        #
        #     # go back to fine level, sweeping
        #     for l in range(self.nlevels - 1, 0, -1):
        #
        #         for S in local_MS_running:
        #             # prolong values
        #             S.transfer(source=S.levels[l], target=S.levels[l - 1])
        #
        #             if l - 1 > 0:
        #                 S.levels[l - 1].sweep.update_nodes()
        #
        #     # end with a fine sweep
        #     for S in local_MS_running:
        #         S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'pfasst_burnin':
            # loop over all steps
            for S in local_MS_running:
                # restrict to coarsest level
                for l in range(1, len(S.levels)):
                    S.transfer(source=S.levels[l - 1], target=S.levels[l])

            # loop over all steps
            for q in range(len(local_MS_running)):
                # loop over last steps: [1,2,3,4], [2,3,4], [3,4], [4]
                for p in range(q, len(local_MS_running)):
                    S = local_MS_running[p]

                    # do the sweep with new values
                    S.levels[-1].sweep.update_nodes()

                    # send updated values on coarsest level
                    self.send_full(S, level=len(S.levels) - 1)

                # loop over last steps: [2,3,4], [3,4], [4]
                for p in range(q + 1, len(local_MS_running)):
                    S = local_MS_running[p]
                    # receive values sent during previous sweep
                    self.recv_full(S, level=len(S.levels) - 1, add_to_stats=(p == len(local_MS_running) - 1))

            # loop over all steps
            for S in local_MS_running:
                # interpolate back to finest level
                for l in range(len(S.levels) - 1, 0, -1):
                    S.transfer(source=S.levels[l], target=S.levels[l - 1])

                # send updated values forward
                self.send_full(S, level=0)
                # receive values
                self.recv_full(S, level=0)

            # end this with a fine sweep
            for S in local_MS_running:
                S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'fmg':
            # TODO: implement FMG predictor
            raise NotImplementedError('FMG predictor is not yet implemented')

        else:
            raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)

        for S in local_MS_running:
            for hook in self.hooks:
                hook.post_predict(step=S, level_number=0)

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_CHECK'

    def it_check(self, local_MS_running):
        """
        Key routine to check for convergence/termination

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            # send updated values forward
            self.send_full(S, level=0)
            # receive values
            self.recv_full(S, level=0)
            # compute current residual
            S.levels[0].sweep.compute_residual(stage='IT_CHECK')

        for S in local_MS_running:
            if S.status.iter > 0:
                for hook in self.hooks:
                    hook.post_iteration(step=S, level_number=0)

            # decide if the step is done, needs to be restarted and other things convergence related
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_iteration_processing(self, S, MS=local_MS_running)
                C.convergence_control(self, S, MS=local_MS_running)

        for S in local_MS_running:
            if not S.status.first:
                for hook in self.hooks:
                    hook.pre_comm(step=S, level_number=0)
                S.status.prev_done = S.prev.status.done  # "communicate"
                for hook in self.hooks:
                    hook.post_comm(step=S, level_number=0, add_to_stats=True)
                S.status.done = S.status.done and S.status.prev_done
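                # a step only counts as done once its predecessor is done,
                # mimicking forward-in-time communication of convergence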

            if self.params.all_to_done:
                for hook in self.hooks:
                    hook.pre_comm(step=S, level_number=0)
                S.status.done = all(T.status.done for T in local_MS_running)
                for hook in self.hooks:
                    hook.post_comm(step=S, level_number=0, add_to_stats=True)

            if not S.status.done:
                # increment iteration count here (and only here)
                S.status.iter += 1
                for hook in self.hooks:
                    hook.pre_iteration(step=S, level_number=0)
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.pre_iteration_processing(self, S, MS=local_MS_running)

                if len(S.levels) > 1:  # MLSDC or PFASST
                    S.status.stage = 'IT_DOWN'
                else:  # SDC or MSSDC
                    if len(local_MS_running) == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                        S.status.stage = 'IT_FINE'
                    else:
                        S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-Seidel-like)
            else:
                S.levels[0].sweep.compute_end_point()
                for hook in self.hooks:
                    hook.post_step(step=S, level_number=0)
                S.status.stage = 'DONE'

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_buffers_nonMPI(self)

    def it_fine(self, local_MS_running):
        """
        Fine sweeps

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            S.levels[0].status.sweep = 0

        for k in range(self.nsweeps[0]):
            for S in local_MS_running:
                S.levels[0].status.sweep += 1

            for S in local_MS_running:
                # send updated values forward
                self.send_full(S, level=0)
                # receive values
                self.recv_full(S, level=0, add_to_stats=(k == self.nsweeps[0] - 1))

            for S in local_MS_running:
                # standard sweep workflow: update nodes, compute residual, log progress
                for hook in self.hooks:
                    hook.pre_sweep(step=S, level_number=0)
                S.levels[0].sweep.update_nodes()
                S.levels[0].sweep.compute_residual(stage='IT_FINE')
                for hook in self.hooks:
                    hook.post_sweep(step=S, level_number=0)

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_CHECK'

    def it_down(self, local_MS_running):
        """
        Go down the hierarchy from finest to coarsest level

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            S.transfer(source=S.levels[0], target=S.levels[1])

        for l in range(1, self.nlevels - 1):
            # sweep on middle levels (not on finest, not on coarsest, though)
            for _ in range(self.nsweeps[l]):
                for S in local_MS_running:
                    # send updated values forward
                    self.send_full(S, level=l)
                    # receive values
                    self.recv_full(S, level=l)

                for S in local_MS_running:
                    for hook in self.hooks:
                        hook.pre_sweep(step=S, level_number=l)
                    S.levels[l].sweep.update_nodes()
                    S.levels[l].sweep.compute_residual(stage='IT_DOWN')
                    for hook in self.hooks:
                        hook.post_sweep(step=S, level_number=l)

            for S in local_MS_running:
                # transfer further down the hierarchy
                S.transfer(source=S.levels[l], target=S.levels[l + 1])

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_COARSE'

    def it_coarse(self, local_MS_running):
        """
        Coarse sweep

        Args:
            local_MS_running (list): list of currently running steps
        """
        for S in local_MS_running:
            # receive from previous step (if not first)
            self.recv_full(S, level=len(S.levels) - 1)

            # do the sweep
            for hook in self.hooks:
                hook.pre_sweep(step=S, level_number=len(S.levels) - 1)
            S.levels[-1].sweep.update_nodes()
            S.levels[-1].sweep.compute_residual(stage='IT_COARSE')
            for hook in self.hooks:
                hook.post_sweep(step=S, level_number=len(S.levels) - 1)

            # send to successor step
            self.send_full(S, level=len(S.levels) - 1, add_to_stats=True)

            # update stage
            if len(S.levels) > 1:  # MLSDC or PFASST
                S.status.stage = 'IT_UP'
            else:  # MSSDC
                S.status.stage = 'IT_CHECK'

    def it_up(self, local_MS_running):
        """
        Prolong corrections up to finest level (parallel)

        Args:
            local_MS_running (list): list of currently running steps
        """

        for l in range(self.nlevels - 1, 0, -1):
            for S in local_MS_running:
                # prolong values
                S.transfer(source=S.levels[l], target=S.levels[l - 1])

            # on middle levels: do communication and sweep as usual
            if l - 1 > 0:
                for k in range(self.nsweeps[l - 1]):
                    for S in local_MS_running:
                        # send updated values forward
                        self.send_full(S, level=l - 1)
                        # receive values
                        self.recv_full(S, level=l - 1, add_to_stats=(k == self.nsweeps[l - 1] - 1))

                    for S in local_MS_running:
                        for hook in self.hooks:
                            hook.pre_sweep(step=S, level_number=l - 1)
                        S.levels[l - 1].sweep.update_nodes()
                        S.levels[l - 1].sweep.compute_residual(stage='IT_UP')
                        for hook in self.hooks:
                            hook.post_sweep(step=S, level_number=l - 1)

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_FINE'

    def default(self, local_MS_running):
        """
        Default routine to catch wrong status

        Args:
            local_MS_running (list): list of currently running steps
        """
        raise ControllerError('Unknown stage, got %s' % local_MS_running[0].status.stage)  # TODO