Coverage for pySDC/implementations/convergence_controller_classes/check_convergence.py: 84%
57 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1import numpy as np
3from pySDC.core.convergence_controller import ConvergenceController
class CheckConvergence(ConvergenceController):
    """
    Perform simple checks on convergence for SDC iterations.

    Iteration is terminated via one of the following criteria:
    - Residual tolerance
    - Embedded error estimate tolerance (only if `e_tol` is given in the level parameters)
    - Maximum number of iterations

    The step status flags `force_done` and `force_continue` can override the outcome.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Define default parameters here.

        Defaults:
            control_order (int): +200
            use_e_tol (bool): Whether `e_tol` is present in `description['level_params']`

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary. Entries returned by the parent class's `setup`
                take precedence over the defaults defined here.
        """
        defaults = {'control_order': +200, 'use_e_tol': 'e_tol' in description['level_params']}

        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    def dependencies(self, controller, description, **kwargs):
        """
        Load the embedded error estimator if needed.

        If `useMPI` is set, the MPI logical operations used in `communicate_convergence`
        (`MPI_LAND`, `MPI_LOR`, `MPI_BOOL`) are prepared here.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        if self.params.useMPI:
            self.prepare_MPI_logical_operations()

        super().dependencies(controller, description)

        if self.params.use_e_tol:
            # imported lazily so the estimator is only loaded when an embedded-error tolerance is used
            from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
                EstimateEmbeddedError,
            )

            controller.add_convergence_controller(
                EstimateEmbeddedError,
                description=description,
            )

        return None

    @staticmethod
    def check_convergence(S, self=None):
        """
        Check the convergence of a single step.
        Test the residual, the embedded error estimate and the max. number of iterations as well as
        allowing overrides to both stop and continue.

        Args:
            S (pySDC.Step): The current step
            self (CheckConvergence, optional): Instance used only to emit debug output; may be omitted
                when this is called as a plain static function

        Returns:
            bool: Convergence status of the step
        """
        # do all this on the finest level
        L = S.levels[0]

        # get residual and check against prescribed tolerance (plus check number of iterations)
        iter_converged = S.status.iter >= S.params.maxiter
        res_converged = L.status.residual <= L.params.restol

        # Only test the embedded error if both a tolerance and an estimate are available. Compare
        # against None explicitly: an estimate of exactly 0.0 is a valid, converged value and must not
        # be mistaken for "no estimate available".
        e_tol = L.params.get('e_tol')
        error_estimate = L.status.get('error_embedded_estimate')
        e_tol_converged = e_tol is not None and error_estimate is not None and error_estimate < e_tol

        # The `or`/`and` chain returns operand values (force_done may be None), so normalise to bool.
        converged = bool(
            (iter_converged or res_converged or e_tol_converged or S.status.force_done)
            and not S.status.force_continue
        )

        # print information for debugging
        if converged and self:
            self.debug(
                f'Declared convergence: maxiter reached[{"x" if iter_converged else " "}] restol reached[{"x" if res_converged else " "}] e_tol reached[{"x" if e_tol_converged else " "}]',
                S,
            )
        return converged

    def check_iteration_status(self, controller, S, **kwargs):
        """
        Routine to determine whether to stop iterating (testing the residual, the embedded error
        estimate and the max. number of iterations).

        If a communicator is passed as `comm`, the convergence status is also exchanged with other
        ranks. `force_continue` is reset afterwards, so it only applies to a single check.

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step

        Returns:
            None
        """
        S.status.done = self.check_convergence(S, self)

        if "comm" in kwargs.keys():
            self.communicate_convergence(controller, S, **kwargs)

        S.status.force_continue = False

        return None

    def communicate_convergence(self, controller, S, comm):
        """
        Communicate the convergence status during `check_iteration_status` if MPI is used.

        With `all_to_done`, all ranks agree on `done` (logical AND) and `force_done` (logical OR).
        Otherwise the status is received from the previous step and sent forward to the next one.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step.step): The current step
            comm (mpi4py.MPI.Comm): MPI communicator

        Returns:
            None
        """
        # Either gather information about all status or send forward own
        if controller.params.all_to_done:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)
            S.status.done = comm.allreduce(sendobj=S.status.done, op=self.MPI_LAND)
            S.status.force_done = comm.allreduce(sendobj=S.status.force_done, op=self.MPI_LOR)
            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)

            S.status.done = S.status.done or S.status.force_done

        else:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)

            # check if an open request of the status send is pending
            controller.wait_with_interrupt(request=controller.req_status)

            # A step that was forced to finish skips the status exchange. The post_comm hooks below
            # still run so that every pre_comm call is paired with a post_comm call.
            if not S.status.force_done:
                # recv status
                if not S.status.first and not S.status.prev_done:
                    buff = np.empty(1, dtype=bool)
                    self.Recv(comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
                    S.status.prev_done = buff[0]
                    S.status.done = S.status.done and S.status.prev_done

                # send status forward
                if not S.status.last:
                    buff = np.empty(1, dtype=bool)
                    buff[0] = S.status.done
                    self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)

        return None