Coverage for pySDC/implementations/convergence_controller_classes/interpolate_between_restarts.py: 93%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import numpy as np
2from pySDC.core.convergence_controller import ConvergenceController, Status
3from qmat.lagrange import LagrangeApproximation
class InterpolateBetweenRestarts(ConvergenceController):
    """
    Interpolate the solution and right hand side to the new set of collocation nodes after a restart.

    The idea is that when you adjust the step size between restarts, you already know what the new quadrature method
    is going to be, and interpolating the current iterate to those nodes may result in a better initial guess than
    spreading the initial conditions or whatever you usually like to do.

    The ``skip_interpolation`` status flag is never set to True by this class itself; presumably another component
    sets it to suppress interpolation for one restart (TODO confirm). It is reset to False after every spread.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Set the default parameters of this convergence controller.

        Args:
            controller (pySDC.Controller.controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            dict: The parameters, with ``control_order`` defaulting to 50 unless overridden
        """
        defaults = {
            'control_order': 50,
        }
        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    def setup_status_variables(self, controller, **kwargs):
        """
        Add status variables holding the interpolated solution and right hand side, plus the control flags.

        Args:
            controller (pySDC.Controller.controller): The controller
        """
        self.status = Status(['u_inter', 'f_inter', 'perform_interpolation', 'skip_interpolation'])

        self._clear_interpolation()
        self.status.skip_interpolation = False

    def _clear_interpolation(self):
        """Discard any stored interpolated values and mark that no interpolation is pending."""
        self.status.perform_interpolation = False
        self.status.u_inter = []
        self.status.f_inter = []

    def post_spread_processing(self, controller, step, **kwargs):
        """
        Spread the interpolated values to the collocation nodes. This overrides whatever the sweeper uses for prediction.

        Interpolation is only applied if it was prepared in ``post_iteration_processing`` and is not being skipped.
        Either way, the stored interpolation data and the skip flag are reset afterwards, so nothing carries over
        into a later step.

        Args:
            controller (pySDC.Controller.controller): The controller
            step (pySDC.Step.step): The current step
        """
        if self.status.perform_interpolation and not self.status.skip_interpolation:
            # u_inter / f_inter hold exactly one entry per level, in the same order as step.levels
            for i, level in enumerate(step.levels):
                # some sweepers leave f[0] unset; allocate it so it can be filled below
                if level.f[0] is None:
                    level.f[0] = level.prob.dtype_f(level.prob.init)

                for m in range(len(level.u)):
                    # the interpolated rows are flattened, so restore the problem's shape before copying in place
                    level.u[m][:] = self.status.u_inter[i][m].reshape(level.prob.init[0])[:]
                    level.f[m][:] = self.status.f_inter[i][m].reshape(level.f[m].shape)[:]

        # whether used or skipped, the stored values belong to the restart that just happened
        self._clear_interpolation()
        self.status.skip_interpolation = False

    def post_iteration_processing(self, controller, step, **kwargs):
        """
        Interpolate the solution and right hand sides and store them in the status, where they will be distributed
        accordingly after the next spread.

        This function is called after every iteration instead of just after the step because we might choose to stop
        iterating as soon as we have decided to restart. If we let the step continue to iterate, this is not the most
        efficient implementation and you may choose to write a different convergence controller. Each call rebuilds
        the stored values from the current iterate, so the latest iterate is the one that gets used.

        The interpolation is a Lagrange interpolation through the points ``{0} U nodes``, evaluated at
        ``{0} U nodes * dt_new / dt``. That is, the new step's nodes are expressed relative to the old step's interval.

        Args:
            controller (pySDC.Controller): The controller
            step (pySDC.Step.step): The current step
        """
        if not (
            step.status.restart
            and all(level.status.dt_new for level in step.levels)
            and not self.status.skip_interpolation
        ):
            self._clear_interpolation()
            return

        # build fresh lists: appending to the stored ones would accumulate one entry per level per iteration
        u_inter = []
        f_inter = []
        for level in step.levels:
            nodes = level.sweep.coll.nodes
            nodes_new = nodes * level.status.dt_new / level.params.dt

            # make sure the right hand side at the initial node is available for interpolation
            if level.f[0] is None:
                level.f[0] = level.prob.eval_f(level.u[0], level.time)

            interpolator = LagrangeApproximation(points=np.append(0, nodes))
            interpolation_matrix = interpolator.getInterpolationMatrix(np.append(0, nodes_new))
            u_inter.append(interpolation_matrix @ [me.flatten() for me in level.u])
            f_inter.append(interpolation_matrix @ [me.flatten() for me in level.f])

            self.log(
                f'Interpolating before restart from dt={level.params.dt:.2e} to dt={level.status.dt_new:.2e}', step
            )

        self.status.u_inter = u_inter
        self.status.f_inter = f_inter
        self.status.perform_interpolation = True
107 self.status.perform_interpolation = False