Coverage for pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py: 97%
67 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1import numpy as np
3from qmat.lagrange import LagrangeApproximation
4from pySDC.core.convergence_controller import ConvergenceController
class EstimatePolynomialError(ConvergenceController):
    """
    Estimate the local error by using all but one collocation node in a polynomial interpolation to that node.

    While the converged collocation problem with M nodes gives an order M approximation to this point, the
    interpolation gives only an order M-1 approximation. Hence, we have two solutions of different, known order, which
    gives an error estimate of order M. Keep in mind that the collocation problem should be converged for this and has
    order up to 2M. Still, the lower order method can be used for time step selection, for instance.

    If the last node is not the end point, we can interpolate to the end point, which is an order M approximation, and
    compare it to the order 2M approximation we get from the extrapolation step.

    By default, we interpolate to the end point (index ``num_nodes + 1`` in ``[0, *nodes, 1]``) for 'GAUSS' nodes and
    to the second to last node (index ``num_nodes - 1``) for every other quadrature type.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Set default parameters, register the logging hook and prepare MPI communication if a communicator is given.

        Args:
            controller (pySDC.Controller.controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary

        Raises:
            ParameterError: If the quadrature type is not 'GAUSS' and ``estimate_on_node`` points past the last
                collocation node, i.e. to the end point.
        """
        from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
        from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence

        sweeper_params = description['sweeper_params']
        num_nodes = sweeper_params['num_nodes']
        quad_type = sweeper_params['quad_type']

        defaults = {
            'control_order': -75,
            # For 'GAUSS' we estimate on the end point (index num_nodes + 1). Every other quadrature type is treated
            # as having the end point as a node, so we fall back to the second to last node.
            'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
            # entries returned by the parent setup come last and therefore override the defaults above
            **super().setup(controller, params, description, **kwargs),
        }
        self.comm = sweeper_params.get('comm', None)

        if self.comm:
            from mpi4py import MPI

            self.prepare_MPI_datatypes()
            self.MPI_SUM = MPI.SUM

        controller.add_hook(LogEmbeddedErrorEstimate)
        # plain function taking only the step; it is called as self.check_convergence(S)
        self.check_convergence = CheckConvergence.check_convergence

        if quad_type != 'GAUSS' and defaults['estimate_on_node'] > num_nodes:
            from pySDC.core.errors import ParameterError

            raise ParameterError(
                'You cannot interpolate with lower accuracy to the end point if the end point is a node!'
            )

        # built lazily on the first converged step in post_iteration_processing
        self.interpolation_matrix = None

        return defaults

    def reset_status_variables(self, *args, **kwargs):
        """
        Add status variables for the embedded error estimate and its order to the levels.

        Returns:
            None
        """
        self.add_status_variable_to_level('error_embedded_estimate')
        self.add_status_variable_to_level('order_embedded_estimate')

    def matmul(self, A, b, xp=np):
        """
        Matrix vector multiplication, possibly MPI parallel.

        The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
        entire vector once could reduce the number of communications, this way we never need to store the entire
        vector on any specific rank.

        Args:
            A (2d np.ndarray): Matrix
            b (list): Vector. In the parallel case, entries this rank does not own may be None.
            xp: Either numpy or cupy

        Returns:
            List of row results in the parallel case, otherwise the array ``A @ b``.
        """
        if not self.comm:
            return A @ xp.asarray(b)

        n_rows = A.shape[0]
        rank = self.comm.rank
        # The rank owning the node we interpolate to has nothing to contribute, because that node is not part of b.
        skipped_rank = self.params.estimate_on_node - 1
        # Rank r contributes entry r + 1 of b for nodes before the skipped one, and entry r for nodes after it
        # (the skipped node is removed from b, shifting the later entries down by one).
        index = rank + (1 if rank < skipped_rank else 0)

        # b[0] is used on every rank, so it is assumed to be available everywhere; its term is added locally.
        res = [A[i, 0] * b[0] if b[0] is not None else None for i in range(n_rows)]
        buf = b[0] * 0.0
        for i in range(n_rows):
            send_buf = A[i, index] * b[index] if rank != skipped_rank else xp.zeros_like(res[0])
            # all ranks must take part in every Allreduce, in the same row order
            self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
            res[i] += buf
        return res

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Estimate the error once the collocation problem has converged.

        Stores the estimate and its order in ``error_embedded_estimate`` and ``order_embedded_estimate`` on the status
        of ``S.levels[0]``. Does nothing if the step is not yet converged.

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step

        Returns:
            None
        """
        if not self.check_convergence(S):
            return

        L = S.levels[0]
        coll = L.sweep.coll
        # left end point 0, the collocation nodes, right end point 1
        nodes = np.append(np.append(0, coll.nodes), 1.0)
        estimate_on_node = self.params.estimate_on_node
        xp = L.u[0].xp

        # The interpolation weights only depend on the nodes, so they are computed once and reused.
        if self.interpolation_matrix is None:
            interpolator = LagrangeApproximation(
                points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
            )
            self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))

        # solution at every node except the one we interpolate to (entries may be None on ranks not owning them)
        u = [
            L.u[i].flatten() if L.u[i] is not None else L.u[i]
            for i in range(coll.num_nodes + 1)
            if i != estimate_on_node
        ]
        u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])

        # compute end point if needed
        if estimate_on_node == len(nodes) - 1:
            if L.uend is None:
                L.sweep.compute_end_point()
            high_order_sol = L.uend
            rank = 0
            L.status.order_embedded_estimate = coll.num_nodes + 1
        else:
            high_order_sol = L.u[estimate_on_node]
            # the rank holding the node we interpolated to computes the error
            rank = estimate_on_node - 1
            L.status.order_embedded_estimate = coll.num_nodes

        if self.comm:
            buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
            self.comm.Bcast(buf, root=rank)
            L.status.error_embedded_estimate = float(buf)
        else:
            L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

        self.debug(
            f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
            S,
        )

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check that there are enough collocation nodes to interpolate from the others to one of them.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        if description['sweeper_params'].get('num_nodes', 0) < 2:
            return False, 'Need at least two collocation nodes to interpolate to one!'

        return True, ""