# Coverage for pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py: 97%

## 67 statements

Report created at 2024-09-09 14:59 +0000

1import numpy as np

3from qmat.lagrange import LagrangeApproximation

4from pySDC.core.convergence_controller import ConvergenceController

class EstimatePolynomialError(ConvergenceController):
    """
    Estimate the local error by using all but one collocation node in a polynomial interpolation to that node.
    While the converged collocation problem with M nodes gives an order M approximation to this point, the
    interpolation gives only an order M-1 approximation. Hence, we have two solutions with different order, and we
    know their order. That is to say, this gives an error estimate that is order M. Keep in mind that the collocation
    problem should be converged for this and has order up to 2M. Still, the lower order method can be used for time
    step selection, for instance.
    If the last node is not the end point, we can interpolate to that node, which is an order M approximation, and
    compare to the order 2M approximation we get from the extrapolation step.
    By default, we interpolate to the second to last node, except for Gauss nodes, where we interpolate to the end
    point.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Set the default parameters and prepare optional MPI communication.

        Args:
            controller (pySDC.Controller.controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary

        Raises:
            ParameterError: If a non-Gauss quadrature is used and `estimate_on_node` points past the last
                collocation node, i.e. to the end point, which is then itself a node.
        """
        from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
        from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence

        sweeper_params = description['sweeper_params']
        num_nodes = sweeper_params['num_nodes']
        # Previously used below without ever being defined, which raised a NameError.
        quad_type = sweeper_params['quad_type']

        defaults = {
            'control_order': -75,
            # Indices refer to the node list [0, collocation nodes..., 1] built in `post_iteration_processing`.
            # Index num_nodes + 1 is the end point (used for Gauss); num_nodes - 1 is the second to last node.
            'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
            **super().setup(controller, params, description, **kwargs),
        }
        self.comm = sweeper_params.get('comm', None)

        if self.comm:
            from mpi4py import MPI

            self.prepare_MPI_datatypes()
            self.MPI_SUM = MPI.SUM

        # Register the hook that the import above exists for; it was previously imported but never used.
        controller.add_hook(LogEmbeddedErrorEstimate)

        self.check_convergence = CheckConvergence.check_convergence

        if quad_type != 'GAUSS' and defaults['estimate_on_node'] > num_nodes:
            from pySDC.core.errors import ParameterError

            raise ParameterError(
                'You cannot interpolate with lower accuracy to the end point if the end point is a node!'
            )

        # Built lazily on the first estimate and cached afterwards.
        self.interpolation_matrix = None

        return defaults

    def reset_status_variables(self, *args, **kwargs):
        """
        Add the level status variables holding the embedded error estimate and its order.

        `post_iteration_processing` writes both of them to `L.status`.

        Returns:
            None
        """
        self.add_status_variable_to_level('error_embedded_estimate')
        self.add_status_variable_to_level('order_embedded_estimate')

    def matmul(self, A, b, xp=np):
        """
        Matrix vector multiplication, possibly MPI parallel.
        The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
        entire vector once could reduce the number of communications, this way we never need to store the entire vector
        on any specific rank.

        Args:
            A (2d np.ndarray): Matrix
            b (list): Vector
            xp: Either numpy or cupy

        Returns:
            List: Axb
        """
        if not self.comm:
            return A @ xp.asarray(b)

        rank = self.comm.rank
        # This rank contributes nothing: its node is the one excluded from `b`.
        excluded_rank = self.params.estimate_on_node - 1
        # Position in `b` of this rank's entry. The excluded node shifts later entries down by one.
        # Assumes rank r holds collocation node r + 1; TODO confirm against the sweeper.
        index = rank + (1 if rank < excluded_rank else 0)

        # Entry 0 of `b` (the initial value) is added locally on every rank; the rest is reduced.
        res = [A[i, 0] * b[0] if b[i] is not None else None for i in range(A.shape[0])]
        buf = b[0] * 0.0
        for i in range(A.shape[0]):
            send_buf = A[i, index] * b[index] if rank != excluded_rank else xp.zeros_like(res[0])
            self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
            res[i] += buf
        return res

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Estimate the error once the collocation problem on the finest level is converged.

        The solution at node `estimate_on_node` is interpolated from all other nodes. The result is compared to the
        collocation solution there (or to the end point). The norm of the difference and its order are stored in
        the level status.

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step

        Returns:
            None
        """
        if not self.check_convergence(S):
            return

        L = S.levels[0]
        coll = L.sweep.coll
        # Node positions including both interval boundaries: [0, collocation nodes..., 1]
        nodes = np.append(np.append(0, coll.nodes), 1.0)
        estimate_on_node = self.params.estimate_on_node
        xp = L.u[0].xp

        if self.interpolation_matrix is None:
            # Interpolate from every node in L.u except the target. The matrix is cached, which assumes the
            # collocation nodes do not change between calls.
            interpolator = LagrangeApproximation(
                points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
            )
            self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))

        # Entries may be None (e.g. nodes not stored on this rank in the MPI case).
        u = [
            L.u[i].flatten() if L.u[i] is not None else L.u[i]
            for i in range(coll.num_nodes + 1)
            if i != estimate_on_node
        ]
        u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])

        if estimate_on_node == len(nodes) - 1:
            # Target is the end point; compute it if the sweeper has not done so yet.
            if L.uend is None:
                L.sweep.compute_end_point()
            high_order_sol = L.uend
            rank = 0
            L.status.order_embedded_estimate = coll.num_nodes + 1
        else:
            high_order_sol = L.u[estimate_on_node]
            rank = estimate_on_node - 1
            L.status.order_embedded_estimate = coll.num_nodes

        if self.comm:
            # Only `rank` forms the difference; the norm is then broadcast to all ranks. Force float64 so the
            # buffer has identical type and size on the root and on the receivers, as Bcast requires.
            buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0, dtype=float)
            self.comm.Bcast(buf, root=rank)
            L.status.error_embedded_estimate = buf
        else:
            L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

        self.debug(
            f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
            S,
        )

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check if we allow the scheme to solve the collocation problems to convergence.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        # With fewer than two nodes there is nothing left to interpolate from after excluding the target.
        if description['sweeper_params'].get('num_nodes', 0) < 2:
            return False, 'Need at least two collocation nodes to interpolate to one!'

        return True, ""