# pySDC/implementations/controller_classes/controller_nonMPI.py  (coverage: 99%, 295 statements)

import itertools
import copy as cp
import numpy as np
import dill

from pySDC.core.controller import Controller
from pySDC.core import step as stepclass
from pySDC.core.errors import ControllerError, CommunicationError
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting


class controller_nonMPI(Controller):
    """
    PFASST controller, running a serialized version of PFASST in blocks (MG-style)
    """
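    # Rough usage sketch (the problem/sweeper classes inside ``description`` are placeholders here,
    # see the pySDC tutorials for concrete setups):
    #
    #     controller = controller_nonMPI(num_procs=4, controller_params=cparams, description=description)
    #     uend, stats = controller.run(u0=u0, t0=0.0, Tend=1.0)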

    def __init__(self, num_procs, controller_params, description):
        """
        Initialization routine for PFASST controller

        Args:
            num_procs: number of parallel time steps (still serial, though), can be 1
            controller_params: parameter set for the controller and the steps
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
        """
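        # Non-exhaustive sketch of typical contents: ``controller_params`` carries the flags read
        # via ``self.params`` below (e.g. ``predict_type``, ``mssdc_jac``, ``all_to_done``,
        # ``dump_setup``), while ``description`` bundles the problem, sweeper, level and step
        # parameters plus the transfer classes used to build each Step.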

        if 'predict' in controller_params:
            raise ControllerError('predict flag is ignored, use predict_type instead')

        # call parent's initialization routine
        super().__init__(controller_params, description, useMPI=False)

        self.MS = [stepclass.Step(description)]

        # try to initialize via dill.copy (much faster for many time-steps)
        try:
            for _ in range(num_procs - 1):
                self.MS.append(dill.copy(self.MS[0]))
        # if this fails (e.g. due to un-picklable data in the steps), initialize separately
        except (dill.PicklingError, TypeError, ValueError) as error:
            self.logger.warning(f'Need to initialize steps separately due to pickling error: {error}')
            for _ in range(num_procs - 1):
                self.MS.append(stepclass.Step(description))

        self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=False)]
        for convergence_controller in self.base_convergence_controllers:
            self.add_convergence_controller(convergence_controller, description)

        if self.params.dump_setup:
            self.dump_setup(step=self.MS[0], controller_params=controller_params, description=description)

        if num_procs > 1 and len(self.MS[0].levels) > 1:
            for S in self.MS:
                for L in S.levels:
                    if not L.sweep.coll.right_is_node:
                        raise ControllerError("For PFASST to work, we assume uend^k = u_M^k")

        if all(len(S.levels) == len(self.MS[0].levels) for S in self.MS):
            self.nlevels = len(self.MS[0].levels)
        else:
            raise ControllerError('all steps need to have the same number of levels')

        if self.nlevels == 0:
            raise ControllerError('need at least one level')

        self.nsweeps = []
        for nl in range(self.nlevels):
            if all(S.levels[nl].params.nsweeps == self.MS[0].levels[nl].params.nsweeps for S in self.MS):
                self.nsweeps.append(self.MS[0].levels[nl].params.nsweeps)

        if self.nlevels > 1 and self.nsweeps[-1] > 1:
            raise ControllerError('this controller cannot do multiple sweeps on coarsest level')

        if self.nlevels == 1 and self.params.predict_type is not None:
            self.logger.warning(
                'you have specified a predictor type but only a single level... predictor will be ignored'
            )

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_buffers_nonMPI(self)
            C.setup_status_variables(self, MS=self.MS)

    def run(self, u0, t0, Tend):
        """
        Main driver for running the serial version of SDC, MSSDC, MLSDC and PFASST (virtual parallelism)

        Args:
            u0: initial values
            t0: starting time
            Tend: ending time

        Returns:
            end values on the finest level
            stats object containing statistics for each step, each level and each iteration
        """
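        # The controller works block-wise: with ``num_procs`` steps it integrates ``num_procs``
        # consecutive time-steps at once, iterates until the whole block is done (or a restart is
        # requested) and then shifts the block forward in time, restarting it from ``uend``.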

        # some initializations and reset of statistics
        uend = None
        num_procs = len(self.MS)
        for hook in self.hooks:
            hook.reset_stats()

        # initial ordering of the steps: 0,1,...,Np-1
        slots = list(range(num_procs))

        # initialize time variables of each step
        time = [t0 + sum(self.MS[j].dt for j in range(p)) for p in slots]

        # determine which steps are still active (time < Tend)
        active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]

        if not any(active):
            raise ControllerError('Nothing to do, check t0, dt and Tend.')

        # compress slots according to active steps, i.e. remove all steps which have times above Tend
        active_slots = list(itertools.compress(slots, active))

        # initialize block of steps with u0
        self.restart_block(active_slots, time, u0)

        for hook in self.hooks:
            hook.post_setup(step=None, level_number=None)

        # call pre-run hook
        for S in self.MS:
            for hook in self.hooks:
                hook.pre_run(step=S, level_number=0)

        # main loop: as long as at least one step is still active (time < Tend), do something
        while any(active):
            MS_active = [self.MS[p] for p in active_slots]
            done = False
            while not done:
                done = self.pfasst(MS_active)

            restarts = [S.status.restart for S in MS_active]
            restart_at = np.where(restarts)[0][0] if True in restarts else len(MS_active)
            if True in restarts:  # restart part of the block
                # initial condition to next block is initial condition of step that needs restarting
                uend = self.MS[restart_at].levels[0].u[0]
                time[active_slots[0]] = time[restart_at]
                self.logger.info(f'Starting next block with initial conditions from step {restart_at}')

            else:  # move on to next block
                # initial condition for next block is last solution of current block
                uend = self.MS[active_slots[-1]].levels[0].uend
                time[active_slots[0]] = time[active_slots[-1]] + self.MS[active_slots[-1]].dt

            for S in MS_active[:restart_at]:
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.post_step_processing(self, S, MS=MS_active)

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                [C.prepare_next_block(self, S, len(active_slots), time, Tend, MS=MS_active) for S in self.MS]

            # setup the times of the steps for the next block
            for i in range(1, len(active_slots)):
                time[active_slots[i]] = time[active_slots[i] - 1] + self.MS[active_slots[i] - 1].dt

            # determine new set of active steps and compress slots accordingly
            active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]
            active_slots = list(itertools.compress(slots, active))

            # restart active steps (reset all values and pass uend to u0)
            self.restart_block(active_slots, time, uend)

        # call post-run hook
        for S in self.MS:
            for hook in self.hooks:
                hook.post_run(step=S, level_number=0)

        return uend, self.return_stats()

    def restart_block(self, active_slots, time, u0):
        """
        Helper routine to reset/restart block of (active) steps

        Args:
            active_slots: list of active steps
            time: list of new times
            u0: initial value to distribute across the steps

        """
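        # Note: for the first active step, ``active_slots[j - 1]`` wraps around (index -1) to the
        # last active slot, so its ``prev`` link points to the end of the block; this is harmless
        # because receives are guarded by ``status.first``.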

        # loop over active slots (not directly, since we need the previous entry as well)
        for j in range(len(active_slots)):
            # get slot number
            p = active_slots[j]

            # store current slot number for diagnostics
            self.MS[p].status.slot = p
            # store link to previous step
            self.MS[p].prev = self.MS[active_slots[j - 1]]
            # resets step
            self.MS[p].reset_step()
            # determine whether I am the first and/or last in line
            self.MS[p].status.first = active_slots.index(p) == 0
            self.MS[p].status.last = active_slots.index(p) == len(active_slots) - 1
            # initialize step with u0
            self.MS[p].init_step(u0)
            # reset some values
            self.MS[p].status.done = False
            self.MS[p].status.prev_done = False
            self.MS[p].status.iter = 0
            self.MS[p].status.stage = 'SPREAD'
            self.MS[p].status.force_done = False
            self.MS[p].status.time_size = len(active_slots)

            for l in self.MS[p].levels:
                l.tag = None
                l.status.sweep = 1

        for p in active_slots:
            for lvl in self.MS[p].levels:
                lvl.status.time = time[p]

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_status_variables(self, active_slots=active_slots)

    def send_full(self, S, level=None, add_to_stats=False):
        """
        Function to perform the send, including bookkeeping and logging

        Args:
            S: the current step
            level: the level number
            add_to_stats: a flag passed to the post_comm hook, telling it to record this communication in the statistics (defaults to False)
        """
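        # The "message" is just a tag tuple ``(level index, iteration, sending slot)`` stored on the
        # level; recv_full compares it to the expected tag to catch out-of-sync communication.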

        def send(source, tag):
            """
            Send function

            Args:
                source: level which has the new values
                tag: identifier for this message
            """
            # sending here means computing uend ("one-sided communication")
            source.sweep.compute_end_point()
            source.tag = cp.deepcopy(tag)

        for hook in self.hooks:
            hook.pre_comm(step=S, level_number=level)
        if not S.status.last:
            self.logger.debug(
                'Process %2i provides data on level %2i with tag %s' % (S.status.slot, level, S.status.iter)
            )
            send(S.levels[level], tag=(level, S.status.iter, S.status.slot))

        for hook in self.hooks:
            hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)

    def recv_full(self, S, level=None, add_to_stats=False):
        """
        Function to perform the recv, including bookkeeping and logging

        Args:
            S: the current step
            level: the level number
            add_to_stats: a flag passed to the post_comm hook, telling it to record this communication in the statistics (defaults to False)
        """

        def recv(target, source, tag=None):
            """
            Receive function

            Args:
                target: level which will receive the values
                source: level which initiated the send
                tag: identifier to check if this message is really for me
            """

            if tag is not None and source.tag != tag:
                raise CommunicationError('source and target tag are not the same, got %s and %s' % (source.tag, tag))
            # simply do a deepcopy of the values uend to become the new u0 at the target
            target.u[0] = target.prob.dtype_u(source.uend)
            # re-evaluate f on left interval boundary
            target.f[0] = target.prob.eval_f(target.u[0], target.time)

        for hook in self.hooks:
            hook.pre_comm(step=S, level_number=level)
        if not S.status.prev_done and not S.status.first:
            self.logger.debug(
                'Process %2i receives from %2i on level %2i with tag %s'
                % (S.status.slot, S.prev.status.slot, level, S.status.iter)
            )
            recv(S.levels[level], S.prev.levels[level], tag=(level, S.status.iter, S.prev.status.slot))
        for hook in self.hooks:
            hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)

    def pfasst(self, local_MS_active):
        """
        Main function including the stages of SDC, MLSDC and PFASST (the "controller")

        For the workflow of this controller, check out one of our PFASST talks or the pySDC paper

        This method changes self.MS directly by accessing active steps through local_MS_active.

        Args:
            local_MS_active (list): all active steps

        Returns:
            bool: True once all active steps are done
        """
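        # Typical stage cycle of a multi-level run:
        #   SPREAD -> PREDICT -> IT_CHECK -> (IT_DOWN -> IT_COARSE -> IT_UP -> IT_FINE -> IT_CHECK)* -> DONE
        # Single-level runs skip PREDICT and the down/up stages and alternate IT_FINE/IT_CHECK
        # (or IT_COARSE/IT_CHECK for serial multi-step SDC).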

        # if all stages are the same (or DONE), continue, otherwise abort
        stages = [S.status.stage for S in local_MS_active if S.status.stage != 'DONE']
        if stages[1:] == stages[:-1]:
            stage = stages[0]
        else:
            raise ControllerError('not all stages are equal')

        self.logger.debug(stage)

        MS_running = [S for S in local_MS_active if S.status.stage != 'DONE']

        switcher = {
            'SPREAD': self.spread,
            'PREDICT': self.predict,
            'IT_CHECK': self.it_check,
            'IT_FINE': self.it_fine,
            'IT_DOWN': self.it_down,
            'IT_COARSE': self.it_coarse,
            'IT_UP': self.it_up,
        }

        switcher.get(stage, self.default)(MS_running)

        return all(S.status.done for S in local_MS_active)

    def spread(self, local_MS_running):
        """
        Spreading phase

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            # first stage: spread values
            for hook in self.hooks:
                hook.pre_step(step=S, level_number=0)

            # call predictor from sweeper
            S.levels[0].sweep.predict()

            # update stage
            if len(S.levels) > 1:  # MLSDC or PFASST with predict
                S.status.stage = 'PREDICT'
            else:
                S.status.stage = 'IT_CHECK'

            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_spread_processing(self, S, MS=local_MS_running)

    def predict(self, local_MS_running):
        """
        Predictor phase

        Args:
            local_MS_running (list): list of currently running steps
        """
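        # Supported values of ``params.predict_type`` handled below: ``None`` (no predictor),
        # ``'fine_only'``, ``'pfasst_burnin'`` (coarse-level burn-in as in classical PFASST) and
        # ``'fmg'`` (not implemented yet); a ``'libpfasst_style'`` variant is kept as commented-out
        # code for reference.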

        for S in local_MS_running:
            for hook in self.hooks:
                hook.pre_predict(step=S, level_number=0)

        if self.params.predict_type is None:
            pass

        elif self.params.predict_type == 'fine_only':
            # do a fine sweep only
            for S in local_MS_running:
                S.levels[0].sweep.update_nodes()

        # elif self.params.predict_type == 'libpfasst_style':
        #
        #     # loop over all steps
        #     for S in local_MS_running:
        #
        #         # restrict to coarsest level
        #         for l in range(1, len(S.levels)):
        #             S.transfer(source=S.levels[l - 1], target=S.levels[l])
        #
        #     # run in serial on coarse level
        #     for S in local_MS_running:
        #
        #         self.hooks.pre_comm(step=S, level_number=len(S.levels) - 1)
        #         # receive from previous step (if not first)
        #         if not S.status.first:
        #             self.logger.debug('Process %2i receives from %2i on level %2i with tag %s -- PREDICT' %
        #                               (S.status.slot, S.prev.status.slot, len(S.levels) - 1, 0))
        #             self.recv(S.levels[-1], S.prev.levels[-1], tag=(len(S.levels), 0, S.prev.status.slot))
        #         self.hooks.post_comm(step=S, level_number=len(S.levels) - 1)
        #
        #         # do the coarse sweep
        #         S.levels[-1].sweep.update_nodes()
        #
        #         self.hooks.pre_comm(step=S, level_number=len(S.levels) - 1)
        #         # send to succ step
        #         if not S.status.last:
        #             self.logger.debug('Process %2i provides data on level %2i with tag %s -- PREDICT'
        #                               % (S.status.slot, len(S.levels) - 1, 0))
        #             self.send(S.levels[-1], tag=(len(S.levels), 0, S.status.slot))
        #         self.hooks.post_comm(step=S, level_number=len(S.levels) - 1, add_to_stats=True)
        #
        #     # go back to fine level, sweeping
        #     for l in range(self.nlevels - 1, 0, -1):
        #
        #         for S in local_MS_running:
        #             # prolong values
        #             S.transfer(source=S.levels[l], target=S.levels[l - 1])
        #
        #             if l - 1 > 0:
        #                 S.levels[l - 1].sweep.update_nodes()
        #
        #     # end with a fine sweep
        #     for S in local_MS_running:
        #         S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'pfasst_burnin':
            # loop over all steps
            for S in local_MS_running:
                # restrict to coarsest level
                for l in range(1, len(S.levels)):
                    S.transfer(source=S.levels[l - 1], target=S.levels[l])

            # loop over all steps
            for q in range(len(local_MS_running)):
                # loop over last steps: [1,2,3,4], [2,3,4], [3,4], [4]
                for p in range(q, len(local_MS_running)):
                    S = local_MS_running[p]

                    # do the sweep with new values
                    S.levels[-1].sweep.update_nodes()

                    # send updated values on coarsest level
                    self.send_full(S, level=len(S.levels) - 1)

                # loop over last steps: [2,3,4], [3,4], [4]
                for p in range(q + 1, len(local_MS_running)):
                    S = local_MS_running[p]
                    # receive values sent during previous sweep
                    self.recv_full(S, level=len(S.levels) - 1, add_to_stats=(p == len(local_MS_running) - 1))

            # loop over all steps
            for S in local_MS_running:
                # interpolate back to finest level
                for l in range(len(S.levels) - 1, 0, -1):
                    S.transfer(source=S.levels[l], target=S.levels[l - 1])

                # send updated values forward
                self.send_full(S, level=0)
                # receive values
                self.recv_full(S, level=0)

            # end this with a fine sweep
            for S in local_MS_running:
                S.levels[0].sweep.update_nodes()

        elif self.params.predict_type == 'fmg':
            # TODO: implement FMG predictor
            raise NotImplementedError('FMG predictor is not yet implemented')

        else:
            raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)

        for S in local_MS_running:
            for hook in self.hooks:
                hook.post_predict(step=S, level_number=0)

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_CHECK'

    def it_check(self, local_MS_running):
        """
        Key routine to check for convergence/termination

        Args:
            local_MS_running (list): list of currently running steps
        """
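        # Convergence is propagated forward only: a non-first step counts as done once it has
        # converged *and* its predecessor is done, unless ``all_to_done`` forces the whole block
        # to finish together.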

        for S in local_MS_running:
            # send updated values forward
            self.send_full(S, level=0)
            # receive values
            self.recv_full(S, level=0)
            # compute current residual
            S.levels[0].sweep.compute_residual(stage='IT_CHECK')

        for S in local_MS_running:
            if S.status.iter > 0:
                for hook in self.hooks:
                    hook.post_iteration(step=S, level_number=0)

            # decide if the step is done, needs to be restarted and other things convergence related
            for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                C.post_iteration_processing(self, S, MS=local_MS_running)
                C.convergence_control(self, S, MS=local_MS_running)

        for S in local_MS_running:
            if not S.status.first:
                for hook in self.hooks:
                    hook.pre_comm(step=S, level_number=0)
                S.status.prev_done = S.prev.status.done  # "communicate"
                for hook in self.hooks:
                    hook.post_comm(step=S, level_number=0, add_to_stats=True)
                S.status.done = S.status.done and S.status.prev_done

            if self.params.all_to_done:
                for hook in self.hooks:
                    hook.pre_comm(step=S, level_number=0)
                S.status.done = all(T.status.done for T in local_MS_running)
                for hook in self.hooks:
                    hook.post_comm(step=S, level_number=0, add_to_stats=True)

            if not S.status.done:
                # increment iteration count here (and only here)
                S.status.iter += 1
                for hook in self.hooks:
                    hook.pre_iteration(step=S, level_number=0)
                for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
                    C.pre_iteration_processing(self, S, MS=local_MS_running)

                if len(S.levels) > 1:  # MLSDC or PFASST
                    S.status.stage = 'IT_DOWN'
                else:  # SDC or MSSDC
                    if len(local_MS_running) == 1 or self.params.mssdc_jac:  # SDC or parallel MSSDC (Jacobi-like)
                        S.status.stage = 'IT_FINE'
                    else:
                        S.status.stage = 'IT_COARSE'  # serial MSSDC (Gauss-like)
            else:
                S.levels[0].sweep.compute_end_point()
                for hook in self.hooks:
                    hook.post_step(step=S, level_number=0)
                S.status.stage = 'DONE'

        for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
            C.reset_buffers_nonMPI(self)

    def it_fine(self, local_MS_running):
        """
        Fine sweeps

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            S.levels[0].status.sweep = 0

        for k in range(self.nsweeps[0]):
            for S in local_MS_running:
                S.levels[0].status.sweep += 1

            for S in local_MS_running:
                # send updated values forward
                self.send_full(S, level=0)
                # receive values
                self.recv_full(S, level=0, add_to_stats=(k == self.nsweeps[0] - 1))

            for S in local_MS_running:
                # standard sweep workflow: update nodes, compute residual, log progress
                for hook in self.hooks:
                    hook.pre_sweep(step=S, level_number=0)
                S.levels[0].sweep.update_nodes()
                S.levels[0].sweep.compute_residual(stage='IT_FINE')
                for hook in self.hooks:
                    hook.post_sweep(step=S, level_number=0)

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_CHECK'

    def it_down(self, local_MS_running):
        """
        Go down the hierarchy from finest to coarsest level

        Args:
            local_MS_running (list): list of currently running steps
        """

        for S in local_MS_running:
            S.transfer(source=S.levels[0], target=S.levels[1])

        for l in range(1, self.nlevels - 1):
            # sweep on middle levels (not on finest, not on coarsest, though)

            for _ in range(self.nsweeps[l]):
                for S in local_MS_running:
                    # send updated values forward
                    self.send_full(S, level=l)
                    # receive values
                    self.recv_full(S, level=l)

                for S in local_MS_running:
                    for hook in self.hooks:
                        hook.pre_sweep(step=S, level_number=l)
                    S.levels[l].sweep.update_nodes()
                    S.levels[l].sweep.compute_residual(stage='IT_DOWN')
                    for hook in self.hooks:
                        hook.post_sweep(step=S, level_number=l)

            for S in local_MS_running:
                # transfer further down the hierarchy
                S.transfer(source=S.levels[l], target=S.levels[l + 1])

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_COARSE'

    def it_coarse(self, local_MS_running):
        """
        Coarse sweep

        Args:
            local_MS_running (list): list of currently running steps
        """
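        # The loop below visits the steps in order, receiving before and sending after the sweep,
        # so each step sees what its predecessor just computed: the coarsest level is swept
        # serially across the block (Gauss-Seidel-like in time).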

        for S in local_MS_running:
            # receive from previous step (if not first)
            self.recv_full(S, level=len(S.levels) - 1)

            # do the sweep
            for hook in self.hooks:
                hook.pre_sweep(step=S, level_number=len(S.levels) - 1)
            S.levels[-1].sweep.update_nodes()
            S.levels[-1].sweep.compute_residual(stage='IT_COARSE')
            for hook in self.hooks:
                hook.post_sweep(step=S, level_number=len(S.levels) - 1)

            # send to succ step
            self.send_full(S, level=len(S.levels) - 1, add_to_stats=True)

            # update stage
            if len(S.levels) > 1:  # MLSDC or PFASST
                S.status.stage = 'IT_UP'
            else:  # MSSDC
                S.status.stage = 'IT_CHECK'

    def it_up(self, local_MS_running):
        """
        Prolong corrections up to finest level (parallel)

        Args:
            local_MS_running (list): list of currently running steps
        """
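        # After each prolongation, intermediate levels are re-swept (with communication); the
        # finest level is not swept here but in the IT_FINE stage that follows.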

        for l in range(self.nlevels - 1, 0, -1):
            for S in local_MS_running:
                # prolong values
                S.transfer(source=S.levels[l], target=S.levels[l - 1])

            # on middle levels: do communication and sweep as usual
            if l - 1 > 0:
                for k in range(self.nsweeps[l - 1]):
                    for S in local_MS_running:
                        # send updated values forward
                        self.send_full(S, level=l - 1)
                        # receive values
                        self.recv_full(S, level=l - 1, add_to_stats=(k == self.nsweeps[l - 1] - 1))

                    for S in local_MS_running:
                        for hook in self.hooks:
                            hook.pre_sweep(step=S, level_number=l - 1)
                        S.levels[l - 1].sweep.update_nodes()
                        S.levels[l - 1].sweep.compute_residual(stage='IT_UP')
                        for hook in self.hooks:
                            hook.post_sweep(step=S, level_number=l - 1)

        for S in local_MS_running:
            # update stage
            S.status.stage = 'IT_FINE'

    def default(self, local_MS_running):
        """
        Default routine to catch wrong status

        Args:
            local_MS_running (list): list of currently running steps
        """
        raise ControllerError('Unknown stage, got %s' % local_MS_running[0].status.stage)  # TODO