Coverage for pySDC/implementations/controller_classes/controller_ParaDiag_nonMPI.py: 97%
214 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-21 06:49 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-21 06:49 +0000
1import itertools
2import numpy as np
4from pySDC.core.controller import ParaDiagController
5from pySDC.core import step as stepclass
6from pySDC.core.errors import ControllerError
7from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting
8from pySDC.helpers.ParaDiagHelper import get_G_inv_matrix
11class controller_ParaDiag_nonMPI(ParaDiagController):
12 """
    ParaDiag controller, running serialized version.

    This controller uses the increment formulation. That is to say, we set up the residual of the all-at-once problem,
    put it on the right-hand side, invert the ParaDiag preconditioner on the left-hand side to compute the increment,
    and then add the increment onto the solution. For this reason, we need to replace the solution values in the steps
    with the residual values before the solves and then put the solution plus increment back into the steps. This is
    somewhat counter-intuitive when you access the `u` variable in the levels, but it is mathematically advantageous.
21 """
23 def __init__(self, num_procs, controller_params, description):
24 """
25 Initialization routine for ParaDiag controller
27 Args:
28 num_procs: number of parallel time steps (still serial, though), can be 1
29 controller_params: parameter set for the controller and the steps
30 description: all the parameters to set up the rest (levels, problems, transfer, ...)
31 """
32 super().__init__(controller_params, description, useMPI=False, n_steps=num_procs)
34 self.MS = []
36 for l in range(num_procs):
37 G_inv = get_G_inv_matrix(l, num_procs, self.params.alpha, description['sweeper_params'])
38 description['sweeper_params']['G_inv'] = G_inv
40 self.MS.append(stepclass.Step(description))
42 self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=False)]
43 for convergence_controller in self.base_convergence_controllers:
44 self.add_convergence_controller(convergence_controller, description)
46 if self.params.dump_setup:
47 self.dump_setup(step=self.MS[0], controller_params=controller_params, description=description)
49 if len(self.MS[0].levels) > 1:
50 raise NotImplementedError('This controller does not support multiple levels')
52 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
53 C.reset_buffers_nonMPI(self)
54 C.setup_status_variables(self, MS=self.MS)
56 def ParaDiag(self, local_MS_active):
57 """
58 Main function for ParaDiag
60 For the workflow of this controller, see https://arxiv.org/abs/2103.12571
62 This method changes self.MS directly by accessing active steps through local_MS_active.
64 Args:
65 local_MS_active (list): all active steps
67 Returns:
68 boot: Whether all steps are done
69 """
71 # if all stages are the same (or DONE), continue, otherwise abort
72 stages = [S.status.stage for S in local_MS_active if S.status.stage != 'DONE']
73 if stages[1:] == stages[:-1]:
74 stage = stages[0]
75 else:
76 raise ControllerError('not all stages are equal')
78 self.logger.debug(stage)
80 MS_running = [S for S in local_MS_active if S.status.stage != 'DONE']
82 switcher = {
83 'SPREAD': self.spread,
84 'IT_CHECK': self.it_check,
85 'IT_PARADIAG': self.it_ParaDiag,
86 }
88 assert stage in switcher.keys(), f'Got unexpected stage {stage!r}'
89 switcher[stage](MS_running)
91 return all(S.status.done for S in local_MS_active)
93 def apply_matrix(self, mat, quantity):
94 """
95 Apply a matrix on the step level. Needs to be square. Puts the result back into the controller.
97 Args:
98 mat: square LxL matrix with L number of steps
99 """
100 L = len(self.MS)
101 assert np.allclose(mat.shape, L)
102 assert len(mat.shape) == 2
104 level = self.MS[0].levels[0]
105 M = level.sweep.params.num_nodes
106 prob = level.prob
108 # buffer for storing the result
109 res = [
110 None,
111 ] * L
113 if quantity == 'residual':
114 me = [S.levels[0].residual for S in self.MS]
115 elif quantity == 'increment':
116 me = [S.levels[0].increment for S in self.MS]
117 else:
118 raise NotImplementedError
120 # compute matrix-vector product
121 for i in range(mat.shape[0]):
122 res[i] = [prob.u_init for _ in range(M)]
123 for j in range(mat.shape[1]):
124 for m in range(M):
125 res[i][m] += mat[i, j] * me[j][m]
127 # put the result in the "output"
128 for i in range(mat.shape[0]):
129 for m in range(M):
130 me[i][m] = res[i][m]
132 def compute_all_at_once_residual(self, local_MS_running):
133 """
134 This requires to communicate the solutions at the end of the steps to be the initial conditions for the next
135 steps. Afterwards, the residual can be computed locally on the steps.
137 Args:
138 local_MS_running (list): list of currently running steps
139 """
141 for S in local_MS_running:
142 # communicate initial conditions
143 S.levels[0].sweep.compute_end_point()
145 for hook in self.hooks:
146 hook.pre_comm(step=S, level_number=0)
148 if not S.status.first:
149 S.levels[0].u[0] = S.prev.levels[0].uend
151 for hook in self.hooks:
152 hook.post_comm(step=S, level_number=0, add_to_stats=True)
154 # compute residuals locally
155 S.levels[0].sweep.compute_residual()
157 def update_solution(self, local_MS_running):
158 """
159 Since we solve for the increment, we need to update the solution between iterations by adding the increment.
161 Args:
162 local_MS_running (list): list of currently running steps
163 """
164 for S in local_MS_running:
165 for m in range(S.levels[0].sweep.coll.num_nodes):
166 S.levels[0].u[m + 1] += S.levels[0].increment[m]
168 def prepare_Jacobians(self, local_MS_running):
169 # get solutions for constructing average Jacobians
170 if self.params.average_jacobian:
171 level = local_MS_running[0].levels[0]
172 M = level.sweep.coll.num_nodes
174 u_avg = [level.prob.dtype_u(level.prob.init, val=0)] * M
176 # communicate average solution
177 for S in local_MS_running:
178 for m in range(M):
179 u_avg[m] += S.levels[0].u[m + 1] / self.n_steps
181 # store the averaged solution in the steps
182 for S in local_MS_running:
183 S.levels[0].u_avg = u_avg
185 def it_ParaDiag(self, local_MS_running):
186 """
187 Do a single ParaDiag iteration. Does the following steps
188 - (1) Compute the residual of the all-at-once / composite collocation problem
189 - (2) Compute an FFT in time to diagonalize the preconditioner
190 - (3) Solve the collocation problems locally on the steps for the increment
191 - (4) Compute iFFT in time to go back to the original base
192 - (5) Update the solution by adding increment
194 Note that this is the only place where we compute the all-at-once residual because it requires communication and
195 swaps the solution values for the residuals. So after the residual tolerance is reached, one more ParaDiag
196 iteration will be done.
198 Args:
199 local_MS_running (list): list of currently running steps
200 """
202 for S in local_MS_running:
203 for hook in self.hooks:
204 hook.pre_sweep(step=S, level_number=0)
206 # communicate average residual for setting up Jacobians for non-linear problems
207 self.prepare_Jacobians(local_MS_running)
209 # compute the all-at-once residual to use as right hand side
210 self.compute_all_at_once_residual(local_MS_running)
212 # weighted FFT of the residual in time
213 self.FFT_in_time(quantity='residual')
215 # perform local solves of "collocation problems" on the steps (can be done in parallel)
216 for S in local_MS_running:
217 assert len(S.levels) == 1, 'Multi-level SDC not implemented in ParaDiag'
218 S.levels[0].sweep.update_nodes()
220 # inverse FFT of the increment in time
221 self.iFFT_in_time(quantity='increment')
223 # get the next iterate by adding increment to previous iterate
224 self.update_solution(local_MS_running)
226 for S in local_MS_running:
227 for hook in self.hooks:
228 hook.post_sweep(step=S, level_number=0)
230 # update stage
231 for S in local_MS_running:
232 S.status.stage = 'IT_CHECK'
234 def it_check(self, local_MS_running):
235 """
236 Key routine to check for convergence/termination
238 Args:
239 local_MS_running (list): list of currently running steps
240 """
242 for S in local_MS_running:
243 if S.status.iter > 0:
244 for hook in self.hooks:
245 hook.post_iteration(step=S, level_number=0)
247 # decide if the step is done, needs to be restarted and other things convergence related
248 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
249 C.post_iteration_processing(self, S, MS=local_MS_running)
250 C.convergence_control(self, S, MS=local_MS_running)
252 for S in local_MS_running:
253 if not S.status.first:
254 for hook in self.hooks:
255 hook.pre_comm(step=S, level_number=0)
256 S.status.prev_done = S.prev.status.done # "communicate"
257 for hook in self.hooks:
258 hook.post_comm(step=S, level_number=0, add_to_stats=True)
259 S.status.done = S.status.done and S.status.prev_done
261 if self.params.all_to_done:
262 for hook in self.hooks:
263 hook.pre_comm(step=S, level_number=0)
264 S.status.done = all(T.status.done for T in local_MS_running)
265 for hook in self.hooks:
266 hook.post_comm(step=S, level_number=0, add_to_stats=True)
268 if not S.status.done:
269 # increment iteration count here (and only here)
270 S.status.iter += 1
271 for hook in self.hooks:
272 hook.pre_iteration(step=S, level_number=0)
273 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
274 C.pre_iteration_processing(self, S, MS=local_MS_running)
276 # Do another ParaDiag iteration
277 S.status.stage = 'IT_PARADIAG'
278 else:
279 S.levels[0].sweep.compute_end_point()
280 for hook in self.hooks:
281 hook.post_step(step=S, level_number=0)
282 S.status.stage = 'DONE'
284 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
285 C.reset_buffers_nonMPI(self)
287 def spread(self, local_MS_running):
288 """
289 Spreading phase
291 Args:
292 local_MS_running (list): list of currently running steps
293 """
295 for S in local_MS_running:
297 # first stage: spread values
298 for hook in self.hooks:
299 hook.pre_step(step=S, level_number=0)
301 # call predictor from sweeper
302 S.levels[0].sweep.predict()
304 # compute the residual
305 S.levels[0].sweep.compute_residual()
307 # update stage
308 S.status.stage = 'IT_CHECK'
310 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
311 C.post_spread_processing(self, S, MS=local_MS_running)
313 def run(self, u0, t0, Tend):
314 """
315 Main driver for running the serial version of ParaDiag
317 Args:
318 u0: initial values
319 t0: starting time
320 Tend: ending time
322 Returns:
323 end values on the last step
324 stats object containing statistics for each step, each level and each iteration
325 """
327 # some initializations and reset of statistics
328 uend = None
329 num_procs = len(self.MS)
330 for hook in self.hooks:
331 hook.reset_stats()
333 # initial ordering of the steps: 0,1,...,Np-1
334 slots = list(range(num_procs))
336 # initialize time variables of each step
337 time = [t0 + sum(self.MS[j].dt for j in range(p)) for p in slots]
339 # determine which steps are still active (time < Tend)
340 active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]
341 if not all(active) and any(active):
342 self.logger.warning(
343 'Warning: This controller will solve past your desired end time until the end of its block!'
344 )
345 active = [
346 True,
347 ] * len(active)
349 if not any(active):
350 raise ControllerError('Nothing to do, check t0, dt and Tend.')
352 # compress slots according to active steps, i.e. remove all steps which have times above Tend
353 active_slots = list(itertools.compress(slots, active))
355 # initialize block of steps with u0
356 self.restart_block(active_slots, time, u0)
358 for hook in self.hooks:
359 hook.post_setup(step=None, level_number=None)
361 # call pre-run hook
362 for S in self.MS:
363 for hook in self.hooks:
364 hook.pre_run(step=S, level_number=0)
366 # main loop: as long as at least one step is still active (time < Tend), do something
367 while any(active):
368 MS_active = [self.MS[p] for p in active_slots]
369 done = False
370 while not done:
371 done = self.ParaDiag(MS_active)
373 restarts = [S.status.restart for S in MS_active]
374 restart_at = np.where(restarts)[0][0] if True in restarts else len(MS_active)
375 if True in restarts: # restart part of the block
376 # initial condition to next block is initial condition of step that needs restarting
377 uend = self.MS[restart_at].levels[0].u[0]
378 time[active_slots[0]] = time[restart_at]
379 self.logger.info(f'Starting next block with initial conditions from step {restart_at}')
381 else: # move on to next block
382 # initial condition for next block is last solution of current block
383 uend = self.MS[active_slots[-1]].levels[0].uend
384 time[active_slots[0]] = time[active_slots[-1]] + self.MS[active_slots[-1]].dt
386 for S in MS_active[:restart_at]:
387 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
388 C.post_step_processing(self, S, MS=MS_active)
390 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
391 [C.prepare_next_block(self, S, len(active_slots), time, Tend, MS=MS_active) for S in self.MS]
393 # setup the times of the steps for the next block
394 for i in range(1, len(active_slots)):
395 time[active_slots[i]] = time[active_slots[i] - 1] + self.MS[active_slots[i] - 1].dt
397 # determine new set of active steps and compress slots accordingly
398 active = [time[p] < Tend - 10 * np.finfo(float).eps for p in slots]
399 if not all(active) and any(active):
400 self.logger.warning(
401 'Warning: This controller will solve past your desired end time until the end of its block!'
402 )
403 active = [
404 True,
405 ] * len(active)
406 active_slots = list(itertools.compress(slots, active))
408 # restart active steps (reset all values and pass uend to u0)
409 self.restart_block(active_slots, time, uend)
411 # call post-run hook
412 for S in self.MS:
413 for hook in self.hooks:
414 hook.post_run(step=S, level_number=0)
416 for S in self.MS:
417 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
418 C.post_run_processing(self, S, MS=MS_active)
420 return uend, self.return_stats()
422 def restart_block(self, active_slots, time, u0):
423 """
424 Helper routine to reset/restart block of (active) steps
426 Args:
427 active_slots: list of active steps
428 time: list of new times
429 u0: initial value to distribute across the steps
431 """
433 for j in range(len(active_slots)):
434 # get slot number
435 p = active_slots[j]
437 # store current slot number for diagnostics
438 self.MS[p].status.slot = p
439 # store link to previous step
440 self.MS[p].prev = self.MS[active_slots[j - 1]]
442 self.MS[p].reset_step()
444 # determine whether I am the first and/or last in line
445 self.MS[p].status.first = active_slots.index(p) == 0
446 self.MS[p].status.last = active_slots.index(p) == len(active_slots) - 1
448 # initialize step with u0
449 self.MS[p].init_step(u0)
451 # setup G^{-1} for new number of active slots
452 # self.MS[j].levels[0].sweep.set_G_inv(get_G_inv_matrix(j, len(active_slots), self.params.alpha, self.description['sweeper_params']))
454 # reset some values
455 self.MS[p].status.done = False
456 self.MS[p].status.prev_done = False
457 self.MS[p].status.iter = 0
458 self.MS[p].status.stage = 'SPREAD'
459 self.MS[p].status.force_done = False
460 self.MS[p].status.time_size = len(active_slots)
462 for l in self.MS[p].levels:
463 l.tag = None
464 l.status.sweep = 1
466 for p in active_slots:
467 for lvl in self.MS[p].levels:
468 lvl.status.time = time[p]
470 for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
471 C.reset_status_variables(self, active_slots=active_slots)