Coverage for pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py: 97%

74 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import numpy as np 

2 

3from pySDC.core.Lagrange import LagrangeApproximation 

4from pySDC.core.ConvergenceController import ConvergenceController, Status 

5from pySDC.core.Collocation import CollBase 

6 

7 

class EstimatePolynomialError(ConvergenceController):
    """
    Estimate the local error by using all but one collocation node in a polynomial interpolation to that node.
    While the converged collocation problem with M nodes gives an order M approximation to this point, the
    interpolation gives only an order M-1 approximation. Hence, we have two solutions with different order, and we
    know their order. That is to say, this gives an error estimate that is order M. Keep in mind that the collocation
    problem should be converged for this and has order up to 2M. Still, the lower order method can be used for time
    step selection, for instance.
    If the last node is not the end point, we can interpolate to that node, which is an order M approximation, and
    compare to the order 2M approximation we get from the extrapolation step.

    By default, we interpolate to the end point for Gauss nodes and to the second to last node otherwise.
    The order stored in `order_embedded_estimate` is M when estimating on a collocation node and M+1 when
    estimating on the end point.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Set default parameters, validate them, prepare MPI communication if the sweeper is parallel and register the
        hook that logs the embedded error estimate.

        Args:
            controller (pySDC.Controller.controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary

        Raises:
            ParameterError: If the estimate is requested beyond the last node for a quadrature type other than GAUSS
        """
        from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
        from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence

        sweeper_params = description['sweeper_params']
        num_nodes = sweeper_params['num_nodes']
        quad_type = sweeper_params['quad_type']

        defaults = {
            'control_order': -75,
            # index into [0, *nodes, 1]: the end point for GAUSS, else the second to last collocation node
            'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
            # user supplied params override the defaults above
            **super().setup(controller, params, description, **kwargs),
        }

        # Validate before registering hooks or preparing MPI so a bad configuration leaves no side effects behind.
        if quad_type != 'GAUSS' and defaults['estimate_on_node'] > num_nodes:
            from pySDC.core.Errors import ParameterError

            raise ParameterError(
                'You cannot interpolate with lower accuracy to the end point if the end point is a node!'
            )

        self.comm = sweeper_params.get('comm', None)

        if self.comm:
            from mpi4py import MPI

            self.prepare_MPI_datatypes()
            self.MPI_SUM = MPI.SUM

        controller.add_hook(LogEmbeddedErrorEstimate)
        self.check_convergence = CheckConvergence.check_convergence

        # built lazily on the first converged step in `post_iteration_processing`
        self.interpolation_matrix = None

        return defaults

    def reset_status_variables(self, controller, **kwargs):
        """
        Add variables for the embedded error estimate and its order to the level status of all relevant steps.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        if 'comm' in kwargs:
            # time-parallel MPI controller: this process only holds a single step
            steps = [controller.S]
        elif 'active_slots' in kwargs:
            steps = [controller.MS[i] for i in kwargs['active_slots']]
        else:
            steps = controller.MS

        where = ["levels", "status"]
        for S in steps:
            self.add_variable(S, name='error_embedded_estimate', where=where, init=None)
            self.add_variable(S, name='order_embedded_estimate', where=where, init=None)

    def matmul(self, A, b):
        """
        Matrix vector multiplication, possibly MPI parallel.
        The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
        entire vector once could reduce the number of communications, this way we never need to store the entire vector
        on any specific rank.

        Args:
            A (2d np.ndarray): Matrix
            b (list): Vector

        Returns:
            List: Axb
        """
        if not self.comm:
            return A @ np.asarray(b)

        rank = self.comm.rank
        excluded_rank = self.params.estimate_on_node - 1
        # The mapping below assumes rank r holds the solution at collocation node r + 1. Since `b` skips the node we
        # estimate on, ranks after the excluded node find their entry one position earlier in `b`.
        index = rank + (1 if rank < excluded_rank else 0)

        # b[0] (the initial value) must be available on every rank; it is also used as the template for the buffer.
        buf = b[0] * 0.0
        res = []
        for i in range(A.shape[0]):
            # the rank owning the excluded node contributes nothing to the sum
            send_buf = np.zeros_like(buf) if rank == excluded_rank else A[i, index] * b[index]
            self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
            res.append(A[i, 0] * b[0] + buf)
        return res

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Estimate the error once the collocation problem is converged by comparing an interpolated value with the
        collocation (or end point) solution at `estimate_on_node`.

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step

        Returns:
            None
        """
        if not self.check_convergence(S):
            return

        L = S.levels[0]
        coll = L.sweep.coll
        # all points: the initial point 0, the collocation nodes and the end point 1
        nodes = np.append(np.append(0, coll.nodes), 1.0)
        estimate_on_node = self.params.estimate_on_node

        if self.interpolation_matrix is None:
            interpolator = LagrangeApproximation(
                points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
            )
            self.interpolation_matrix = interpolator.getInterpolationMatrix([nodes[estimate_on_node]])

        u = [
            L.u[i].flatten() if L.u[i] is not None else L.u[i]
            for i in range(coll.num_nodes + 1)
            if i != estimate_on_node
        ]
        u_inter = self.matmul(self.interpolation_matrix, u)[0].reshape(L.prob.init[0])

        # compute end point if needed
        if estimate_on_node == len(nodes) - 1:
            if L.uend is None:
                L.sweep.compute_end_point()
            high_order_sol = L.uend
            rank = 0
            L.status.order_embedded_estimate = coll.num_nodes + 1
        else:
            high_order_sol = L.u[estimate_on_node]
            # rank holding collocation node `estimate_on_node` in the MPI parallel case
            rank = estimate_on_node - 1
            L.status.order_embedded_estimate = coll.num_nodes

        if self.comm:
            # only the rank holding the high order solution can compute the difference; share it with everyone
            buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
            self.comm.Bcast(buf, root=rank)
            L.status.error_embedded_estimate = buf
        else:
            L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

        self.debug(
            f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
            S,
        )

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check if we allow the scheme to solve the collocation problems to convergence.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        if description['sweeper_params'].get('num_nodes', 0) < 2:
            return False, 'Need at least two collocation nodes to interpolate to one!'

        return True, ""