Coverage for pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py: 97%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import numpy as np 

2 

3from qmat.lagrange import LagrangeApproximation 

4from pySDC.core.convergence_controller import ConvergenceController 

5 

6 

class EstimatePolynomialError(ConvergenceController):
    """
    Estimate the local error by interpolating from all but one collocation node to the remaining node (or to the end
    point) and comparing the interpolated value with the collocation solution there.

    While the converged collocation problem with M nodes gives an order M approximation at this point, the
    interpolation gives only an order M-1 approximation. Hence, we have two solutions of different, known order, which
    is to say this gives an error estimate of order M. Keep in mind that the collocation problem should be converged
    for this and has order up to 2M. Still, the lower order estimate can be used for time step selection, for instance.

    If the last node is not the end point, we can interpolate to the end point, which is an order M approximation, and
    compare to the order 2M approximation we get from the extrapolation step.

    By default, we interpolate to the end point for Gauss nodes and to the second to last node otherwise.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Set default parameters, register the logging hook and prepare MPI communication if a communicator is given.

        Args:
            controller (pySDC.Controller.controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary

        Raises:
            ParameterError: If `estimate_on_node` refers to the end point although the end point is a collocation node,
                or if it lies beyond the end point.
        """
        from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
        from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
        from pySDC.core.errors import ParameterError

        sweeper_params = description['sweeper_params']
        num_nodes = sweeper_params['num_nodes']
        quad_type = sweeper_params['quad_type']

        defaults = {
            'control_order': -75,
            # Index into [0, *collocation nodes, 1]: the end point for Gauss nodes, else the second to last node.
            'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
            # entries returned by the base class (presumably including the user params) override the defaults above
            **super().setup(controller, params, description, **kwargs),
        }

        # Validate before registering hooks or setting up communication so a bad configuration has no side effects.
        estimate_on_node = defaults['estimate_on_node']
        if quad_type != 'GAUSS' and estimate_on_node > num_nodes:
            raise ParameterError(
                'You cannot interpolate with lower accuracy to the end point if the end point is a node!'
            )
        if estimate_on_node > num_nodes + 1:
            # Index num_nodes + 1 is the end point; anything larger does not exist in the list of points.
            raise ParameterError(
                f'Cannot estimate the error on node {estimate_on_node}: the largest admissible index is the end point '
                f'at {num_nodes + 1}!'
            )

        self.comm = sweeper_params.get('comm', None)
        if self.comm:
            from mpi4py import MPI

            self.prepare_MPI_datatypes()
            self.MPI_SUM = MPI.SUM

        controller.add_hook(LogEmbeddedErrorEstimate)
        # Stored as a plain function on the instance, so it is called with the step only (no `self` binding).
        self.check_convergence = CheckConvergence.check_convergence

        # Built lazily on the first converged step in `post_iteration_processing` and reused afterwards.
        self.interpolation_matrix = None

        return defaults

    def reset_status_variables(self, *args, **kwargs):
        """
        Add variables for the embedded error estimate and its order to the level status.

        Returns:
            None
        """
        self.add_status_variable_to_level('error_embedded_estimate')
        self.add_status_variable_to_level('order_embedded_estimate')

    def matmul(self, A, b, xp=np):
        """
        Matrix vector multiplication, possibly MPI parallel.
        The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
        entire vector once could reduce the number of communications, this way we never need to store the entire vector
        on any specific rank.

        Args:
            A (2d np.ndarray): Matrix
            b (list): Vector. In the parallel case `b[0]` must be available on every rank.
            xp: Either numpy or cupy

        Returns:
            List: Axb
        """
        if not self.comm:
            return A @ xp.asarray(b)

        rank = self.comm.rank
        estimate_on_node = self.params.estimate_on_node

        # Position of this rank's entry in `b`, which has the `estimate_on_node` entry removed. Presumably rank r holds
        # node r + 1 (b[0] being the initial condition); nodes after the removed one shift down by one. TODO confirm
        # against the MPI sweeper.
        index = rank + (1 if rank < estimate_on_node - 1 else 0)
        # The rank holding the node we estimate on has no interpolation data to contribute.
        contributes = rank != estimate_on_node - 1

        # Every rank knows b[0], so the first column is added locally; the other columns are summed across ranks.
        res = [A[i, 0] * b[0] for i in range(A.shape[0])]
        buf = b[0] * 0.0
        for i in range(A.shape[0]):
            send_buf = (A[i, index] * b[index]) if contributes else xp.zeros_like(res[0])
            self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
            res[i] += buf
        return res

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Estimate the error once the step has converged and store it, together with its order, in the level status.

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step

        Returns:
            None
        """
        if not self.check_convergence(S):
            return

        L = S.levels[0]
        coll = L.sweep.coll
        # Candidate points: 0, the collocation nodes and 1.0, the latter being the end point.
        nodes = np.concatenate(([0.0], coll.nodes, [1.0]))
        estimate_on_node = self.params.estimate_on_node
        xp = L.u[0].xp

        if self.interpolation_matrix is None:
            # NOTE(review): if the quadrature includes the left end point, 0 appears twice among these points;
            # confirm LagrangeApproximation copes with repeated points.
            interpolator = LagrangeApproximation(
                points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
            )
            self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))

        # Flattened solutions at the interpolation points, in the same order as the points above. Entries that are None
        # (presumably nodes held by other MPI ranks) are passed through as None.
        u = [
            L.u[i].flatten() if L.u[i] is not None else None
            for i in range(coll.num_nodes + 1)
            if i != estimate_on_node
        ]
        u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])

        if estimate_on_node == len(nodes) - 1:
            # Estimating on the end point: make sure it has been computed.
            if L.uend is None:
                L.sweep.compute_end_point()
            high_order_sol = L.uend
            rank = 0
            L.status.order_embedded_estimate = coll.num_nodes + 1
        else:
            high_order_sol = L.u[estimate_on_node]
            rank = estimate_on_node - 1
            L.status.order_embedded_estimate = coll.num_nodes

        if self.comm:
            # Only `rank` evaluates the difference; it is then broadcast to all other ranks.
            buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
            self.comm.Bcast(buf, root=rank)
            L.status.error_embedded_estimate = float(buf)
        else:
            L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

        self.debug(
            f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
            S,
        )

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check if we allow the scheme to solve the collocation problems to convergence.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        if description['sweeper_params'].get('num_nodes', 0) < 2:
            return False, 'Need at least two collocation nodes to interpolate to one!'

        return True, ""
87 List: Axb 

88 """ 

89 

90 if self.comm: 

91 res = [A[i, 0] * b[0] if b[i] is not None else None for i in range(A.shape[0])] 

92 buf = b[0] * 0.0 

93 for i in range(0, A.shape[0]): 

94 index = self.comm.rank + (1 if self.comm.rank < self.params.estimate_on_node - 1 else 0) 

95 send_buf = ( 

96 (A[i, index] * b[index]) 

97 if self.comm.rank != self.params.estimate_on_node - 1 

98 else xp.zeros_like(res[0]) 

99 ) 

100 self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM) 

101 res[i] += buf 

102 return res 

103 else: 

104 return A @ xp.asarray(b) 

105 

106 def post_iteration_processing(self, controller, S, **kwargs): 

107 """ 

108 Estimate the error 

109 

110 Args: 

111 controller (pySDC.Controller.controller): The controller 

112 S (pySDC.Step.step): The current step 

113 

114 Returns: 

115 None 

116 """ 

117 

118 if self.check_convergence(S): 

119 L = S.levels[0] 

120 coll = L.sweep.coll 

121 nodes = np.append(np.append(0, coll.nodes), 1.0) 

122 estimate_on_node = self.params.estimate_on_node 

123 xp = L.u[0].xp 

124 

125 if self.interpolation_matrix is None: 

126 interpolator = LagrangeApproximation( 

127 points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node] 

128 ) 

129 self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]])) 

130 

131 u = [ 

132 L.u[i].flatten() if L.u[i] is not None else L.u[i] 

133 for i in range(coll.num_nodes + 1) 

134 if i != estimate_on_node 

135 ] 

136 u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0]) 

137 

138 # compute end point if needed 

139 if estimate_on_node == len(nodes) - 1: 

140 if L.uend is None: 

141 L.sweep.compute_end_point() 

142 high_order_sol = L.uend 

143 rank = 0 

144 L.status.order_embedded_estimate = coll.num_nodes + 1 

145 else: 

146 high_order_sol = L.u[estimate_on_node] 

147 rank = estimate_on_node - 1 

148 L.status.order_embedded_estimate = coll.num_nodes * 1 

149 

150 if self.comm: 

151 buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0) 

152 self.comm.Bcast(buf, root=rank) 

153 L.status.error_embedded_estimate = float(buf) 

154 else: 

155 L.status.error_embedded_estimate = abs(u_inter - high_order_sol) 

156 

157 self.debug( 

158 f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}', 

159 S, 

160 ) 

161 

162 def check_parameters(self, controller, params, description, **kwargs): 

163 """ 

164 Check if we allow the scheme to solve the collocation problems to convergence. 

165 

166 Args: 

167 controller (pySDC.Controller): The controller 

168 params (dict): The params passed for this specific convergence controller 

169 description (dict): The description object used to instantiate the controller 

170 

171 Returns: 

172 bool: Whether the parameters are compatible 

173 str: The error message 

174 """ 

175 if description['sweeper_params'].get('num_nodes', 0) < 2: 

176 return False, 'Need at least two collocation nodes to interpolate to one!' 

177 

178 return True, ""