Coverage for pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py: 97%
74 statements
Created by coverage.py v7.5.0 at 2024-04-29 09:02 +0000
1import numpy as np
3from pySDC.core.Lagrange import LagrangeApproximation
4from pySDC.core.ConvergenceController import ConvergenceController, Status
5from pySDC.core.Collocation import CollBase
class EstimatePolynomialError(ConvergenceController):
    """
    Estimate the local error by using all but one collocation node in a polynomial interpolation to that node.
    While the converged collocation problem with M nodes gives an order M approximation to this point, the
    interpolation gives only an order M-1 approximation. Hence, we have two solutions with different order, and we
    know their order. That is to say this gives an error estimate that is order M. Keep in mind that the collocation
    problem should be converged for this and has order up to 2M. Still, the lower order method can be used for time
    step selection, for instance.
    If the last node is not the end point, we can interpolate to that node, which is an order M approximation and
    compare to the order 2M approximation we get from the extrapolation step.
    By default, we interpolate to the end point for GAUSS quadrature and to the second to last node otherwise.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Set default parameters, prepare MPI communication if the sweeper carries a communicator and register the hook
        that logs the embedded error estimate.

        Args:
            controller (pySDC.Controller.controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary

        Raises:
            ParameterError: If the quadrature is not GAUSS and the requested node lies beyond the last collocation node
        """
        from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
        from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence

        sweeper_params = description['sweeper_params']
        num_nodes = sweeper_params['num_nodes']
        quad_type = sweeper_params['quad_type']

        defaults = {
            'control_order': -75,
            # Indices refer to the node array [0, *coll.nodes, 1] built in `post_iteration_processing`: for GAUSS the
            # default target is the end point (index num_nodes + 1), otherwise the second to last collocation node.
            'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
            **super().setup(controller, params, description, **kwargs),
        }
        self.comm = sweeper_params.get('comm', None)

        if self.comm:
            from mpi4py import MPI

            self.prepare_MPI_datatypes()
            self.MPI_SUM = MPI.SUM

        controller.add_hook(LogEmbeddedErrorEstimate)
        self.check_convergence = CheckConvergence.check_convergence

        if quad_type != 'GAUSS' and defaults['estimate_on_node'] > num_nodes:
            from pySDC.core.Errors import ParameterError

            raise ParameterError(
                'You cannot interpolate with lower accuracy to the end point if the end point is a node!'
            )

        # The interpolation matrix is built lazily on the first estimate. The key remembers which nodes / target node
        # it was built for, so that it is rebuilt if those change.
        self.interpolation_matrix = None
        self._interpolation_key = None

        return defaults

    def reset_status_variables(self, controller, **kwargs):
        """
        Add status variables for the embedded error estimate and its order (`where=["levels", "status"]`) to the
        relevant steps.

        Args:
            controller (pySDC.Controller): The controller

        Returns:
            None
        """
        if 'comm' in kwargs:
            steps = [controller.S]
        elif 'active_slots' in kwargs:
            steps = [controller.MS[i] for i in kwargs['active_slots']]
        else:
            steps = controller.MS

        where = ["levels", "status"]
        for S in steps:
            self.add_variable(S, name='error_embedded_estimate', where=where, init=None)
            self.add_variable(S, name='order_embedded_estimate', where=where, init=None)

    def matmul(self, A, b):
        """
        Matrix vector multiplication, possibly MPI parallel.
        The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
        entire vector once could reduce the number of communications, this way we never need to store the entire vector
        on any specific rank.

        In the parallel case every rank must hold `b[0]`. Rank `r` contributes the entry belonging to collocation node
        `r + 1`. Since the entry for `estimate_on_node` has been removed from `b`, the rank owning that node
        contributes zeros.

        Args:
            A (2d np.ndarray): Matrix
            b (list): Vector

        Returns:
            List: Axb
        """
        if not self.comm:
            return A @ np.asarray(b)

        rank = self.comm.rank
        estimate_on_node = self.params.estimate_on_node

        # Position of this rank's node in `b`, which lacks the entry of `estimate_on_node`. This is loop invariant.
        index = rank + (1 if rank < estimate_on_node - 1 else 0)
        contributes = rank != estimate_on_node - 1

        # b[0] is available on every rank, so the first column is handled locally in every row.
        res = [A[i, 0] * b[0] for i in range(A.shape[0])]
        buf = b[0] * 0.0
        for i in range(A.shape[0]):
            send_buf = A[i, index] * b[index] if contributes else np.zeros_like(res[i])
            self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
            res[i] += buf
        return res

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Estimate the error once the collocation problem has converged, and store the estimate and its order in the
        status of the finest level.

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step

        Returns:
            None
        """
        if not self.check_convergence(S):
            return

        L = S.levels[0]
        coll = L.sweep.coll
        # Node array including the initial condition at 0 and the end point at 1.
        nodes = np.append(np.append(0, coll.nodes), 1.0)
        estimate_on_node = self.params.estimate_on_node

        # (Re)build the interpolation matrix if it was never built or the nodes / target node changed since.
        key = (tuple(nodes), estimate_on_node)
        if self.interpolation_matrix is None or key != self._interpolation_key:
            interpolator = LagrangeApproximation(
                points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
            )
            self.interpolation_matrix = interpolator.getInterpolationMatrix([nodes[estimate_on_node]])
            self._interpolation_key = key

        u = [
            L.u[i].flatten() if L.u[i] is not None else L.u[i]
            for i in range(coll.num_nodes + 1)
            if i != estimate_on_node
        ]
        u_inter = self.matmul(self.interpolation_matrix, u)[0].reshape(L.prob.init[0])

        if estimate_on_node == len(nodes) - 1:
            # The target is the end point, so compute it if it is not available yet.
            if L.uend is None:
                L.sweep.compute_end_point()
            high_order_sol = L.uend
            rank = 0
            L.status.order_embedded_estimate = coll.num_nodes + 1
        else:
            high_order_sol = L.u[estimate_on_node]
            # With MPI, node k lives on rank k - 1.
            rank = estimate_on_node - 1
            L.status.order_embedded_estimate = coll.num_nodes

        if self.comm:
            # Only the rank holding the high order solution can compute the error; share it with everyone.
            buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
            self.comm.Bcast(buf, root=rank)
            L.status.error_embedded_estimate = buf
        else:
            L.status.error_embedded_estimate = abs(u_inter - high_order_sol)

        self.debug(
            f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
            S,
        )

    def check_parameters(self, controller, params, description, **kwargs):
        """
        Check that there are enough collocation nodes to interpolate to one of them from the others.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            bool: Whether the parameters are compatible
            str: The error message
        """
        if description['sweeper_params'].get('num_nodes', 0) < 2:
            return False, 'Need at least two collocation nodes to interpolate to one!'

        return True, ""