Coverage for pySDC/implementations/convergence_controller_classes/check_convergence.py: 84%
57 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import numpy as np
3from pySDC.core.convergence_controller import ConvergenceController
class CheckConvergence(ConvergenceController):
    """
    Perform simple checks on convergence for SDC iterations.

    Iteration is terminated via one of the following criteria:
        - Residual tolerance (`restol`)
        - Increment tolerance (`e_tol`), if it is set in the level parameters
        - Maximum number of iterations (`maxiter`)

    The step status flags `force_done` and `force_continue` can override the outcome.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Define default parameters here.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        defaults = {
            'control_order': +200,
            # the increment-based criterion is only switched on if the user supplied `e_tol`
            'use_e_tol': 'e_tol' in description['level_params'],
        }

        # values returned by the parent class take precedence over the defaults above
        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    def dependencies(self, controller, description, **kwargs):
        """
        Prepare MPI logical operations if MPI is used, and load the embedded error estimator if `e_tol` is used.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        if self.params.useMPI:
            # provides self.MPI_LAND / self.MPI_LOR / self.MPI_BOOL used in `communicate_convergence`
            self.prepare_MPI_logical_operations()

        super().dependencies(controller, description)

        if self.params.use_e_tol:
            from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
                EstimateEmbeddedError,
            )

            controller.add_convergence_controller(
                EstimateEmbeddedError,
                description=description,
            )

        return None

    @staticmethod
    def check_convergence(S, self=None):
        """
        Check the convergence of a single step.

        Test the residual, the increment (if `e_tol` is set) and the max. number of iterations,
        and honour the `force_done` / `force_continue` overrides stored in the step status.

        Args:
            S (pySDC.Step): The current step
            self (CheckConvergence, optional): Instance used only for debug output; nothing is logged if omitted

        Returns:
            bool: Convergence status of the step
        """
        # do all this on the finest level
        L = S.levels[0]

        # get residual and check against prescribed tolerance (plus check number of iterations)
        iter_converged = S.status.iter >= S.params.maxiter
        # the residual criterion only counts once at least one iteration or sweep has been done
        res_converged = L.status.residual <= L.params.restol and (S.status.iter > 0 or L.status.sweep > 0)

        # Compare the increment against `None` instead of relying on truthiness: an increment of
        # exactly zero means the iteration has stagnated and must satisfy `increment < e_tol`.
        e_tol = L.params.get('e_tol')
        increment = L.status.get('increment')
        e_tol_converged = bool(e_tol) and increment is not None and increment < e_tol

        converged = (
            iter_converged or res_converged or e_tol_converged or S.status.force_done
        ) and not S.status.force_continue
        # `force_done` may be `None` (or another non-bool falsy value) and the comparisons may yield
        # numpy booleans; normalise to a plain bool so callers always get a proper flag
        converged = bool(converged)

        # print information for debugging
        if converged and self:
            self.debug(
                f'Declared convergence: maxiter reached[{"x" if iter_converged else " "}] restol reached[{"x" if res_converged else " "}] e_tol reached[{"x" if e_tol_converged else " "}]',
                S,
            )
        return converged

    def check_iteration_status(self, controller, S, **kwargs):
        """
        Routine to determine whether to stop iterating (currently testing the residual + the max. number of iterations)

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step
            **kwargs: If a `comm` keyword is present, the convergence status is communicated via MPI

        Returns:
            None
        """
        S.status.done = self.check_convergence(S, self)

        if "comm" in kwargs:
            self.communicate_convergence(controller, S, **kwargs)

        # the override only applies to a single convergence check
        S.status.force_continue = False

        return None

    def communicate_convergence(self, controller, S, comm):
        """
        Communicate the convergence status during `check_iteration_status` if MPI is used.

        With `controller.params.all_to_done` all steps agree on a common status via `allreduce`.
        Otherwise each step receives the status of its predecessor and forwards its own to the next step.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step.step): The current step
            comm (mpi4py.MPI.Comm): MPI communicator

        Returns:
            None
        """
        # Either gather information about all status or send forward own
        if controller.params.all_to_done:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)
            # done only if every step is done; force_done if any step forces it
            S.status.done = comm.allreduce(sendobj=S.status.done, op=self.MPI_LAND)
            S.status.force_done = comm.allreduce(sendobj=S.status.force_done, op=self.MPI_LOR)
            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)

            S.status.done = S.status.done or S.status.force_done

        else:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)

            # check if an open request of the status send is pending
            controller.wait_with_interrupt(request=controller.req_status)
            # NOTE(review): this early exit skips the `post_comm` hooks started above — confirm intended
            if S.status.force_done:
                return None

            # recv status, unless this is the first step or the predecessor is already known to be done
            if not S.status.first and not S.status.prev_done:
                buff = np.empty(1, dtype=bool)
                self.Recv(comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
                S.status.prev_done = buff[0]
                # a step can only be done once its predecessor is done
                S.status.done = S.status.done and S.status.prev_done

            # send status forward
            if not S.status.last:
                buff = np.empty(1, dtype=bool)
                buff[0] = S.status.done
                self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)