Coverage for pySDC/implementations/convergence_controller_classes/interpolate_between_restarts.py: 95%

41 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import numpy as np 

2from pySDC.core.ConvergenceController import ConvergenceController, Status 

3from pySDC.core.Lagrange import LagrangeApproximation 

4 

5 

class InterpolateBetweenRestarts(ConvergenceController):
    """
    Interpolate the solution and right hand side to the new set of collocation nodes after a restart.

    The idea is that when you adjust the step size between restarts, you already know what the new quadrature method
    is going to be, and interpolating the current iterate to these new nodes possibly results in a better initial guess
    than spreading the initial conditions or whatever the sweeper usually does for prediction.

    Status variables:
        u_inter (list): One array per level holding the solution interpolated to the initial point and the new
                        collocation nodes (row ``m`` corresponds to entry ``m`` of ``level.u``, flattened)
        f_inter (list): Same as ``u_inter``, but for the right hand side
        perform_interpolation (bool): Whether ``post_spread_processing`` should overwrite the spread values
        skip_interpolation (bool): Suppresses computing and applying the interpolation. It is reset after every
                                   spread; presumably other convergence controllers set it, since this class never
                                   sets it to True (verify against callers)
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Set the default parameters of this convergence controller.

        Args:
            controller (pySDC.Controller.controller): The controller
            params (dict): Parameters for the convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            dict: The parameters, where values returned by the parent class's ``setup`` take precedence over the
                  defaults defined here
        """
        defaults = {
            # presumably determines when this controller runs relative to other convergence controllers
            'control_order': 50,
        }
        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    def setup_status_variables(self, controller, **kwargs):
        """
        Add variables to the status containing the interpolated solution and right hand side.

        Args:
            controller (pySDC.Controller.controller): The controller
        """
        self.status = Status(['u_inter', 'f_inter', 'perform_interpolation', 'skip_interpolation'])

        self.status.u_inter = []
        self.status.f_inter = []
        self.status.perform_interpolation = False
        self.status.skip_interpolation = False

    def post_spread_processing(self, controller, step, **kwargs):
        """
        Spread the interpolated values to the collocation nodes. This overrides whatever the sweeper uses for prediction.

        The interpolated values are consumed here: afterwards the stored interpolation is discarded and
        ``skip_interpolation`` is reset to False.

        Args:
            controller (pySDC.Controller.controller): The controller
            step (pySDC.Step.step): The current step
        """
        if self.status.perform_interpolation and not self.status.skip_interpolation:
            for i, level in enumerate(step.levels):
                u_new = self.status.u_inter[i]
                f_new = self.status.f_inter[i]

                # f[0] may not have been allocated yet; give it storage we can copy into
                if level.f[0] is None:
                    level.f[0] = level.prob.dtype_f(level.prob.init)

                # the interpolated rows are flattened; restore the problem's shape entry before copying in place
                shape = level.prob.init[0]
                for m in range(len(level.u)):
                    level.u[m][:] = u_new[m].reshape(shape)
                    level.f[m][:] = f_new[m].reshape(shape)

            # reset the status variables
            self.status.perform_interpolation = False
            self.status.u_inter = []
            self.status.f_inter = []

        self.status.skip_interpolation = False

    def post_iteration_processing(self, controller, step, **kwargs):
        """
        Interpolate the solution and right hand sides and store them in the status, where they will be distributed
        accordingly after the next spread.

        This function is called after every iteration instead of just after the step because we might choose to stop
        iterating as soon as we have decided to restart. If we let the step continue to iterate, this is not the most
        efficient implementation, and you may choose to write a different convergence controller. Each call replaces
        any previously stored interpolation, so the values applied after the restart always come from the latest
        iterate.

        On every level, a Lagrange polynomial through the points ``0`` and the current collocation nodes is evaluated
        at ``0`` and the new nodes. The new nodes are the current ones scaled by ``dt_new / dt``, i.e. the nodes of
        the restarted step expressed in the time scale of the current step.

        Args:
            controller (pySDC.Controller): The controller
            step (pySDC.Step.step): The current step
        """
        interpolate = (
            step.status.restart
            and all(level.status.dt_new for level in step.levels)
            and not self.status.skip_interpolation
        )

        if not interpolate:
            # nothing to apply; also drop any stale interpolation so it cannot be picked up later
            self.status.perform_interpolation = False
            self.status.u_inter = []
            self.status.f_inter = []
            return

        # Build fresh lists instead of appending: post_spread_processing reads them by level index, so leftovers
        # from earlier iterations would shadow the interpolation of the current iterate.
        u_inter = []
        f_inter = []
        for level in step.levels:
            nodes_old = level.sweep.coll.nodes
            nodes_new = nodes_old * level.status.dt_new / level.params.dt

            # the right hand side at the initial point is needed as interpolation data as well
            if level.f[0] is None:
                level.f[0] = level.prob.eval_f(level.u[0], level.time)

            interpolator = LagrangeApproximation(points=np.append(0, nodes_old))
            interpolation_matrix = interpolator.getInterpolationMatrix(np.append(0, nodes_new))
            u_inter.append(interpolation_matrix @ [me.flatten() for me in level.u])
            f_inter.append(interpolation_matrix @ [me.flatten() for me in level.f])

            self.log(
                f'Interpolating before restart from dt={level.params.dt:.2e} to dt={level.status.dt_new:.2e}', step
            )

        self.status.u_inter = u_inter
        self.status.f_inter = f_inter
        self.status.perform_interpolation = True