Coverage for pySDC/implementations/convergence_controller_classes/check_convergence.py: 84%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1import numpy as np 

2 

3from pySDC.core.convergence_controller import ConvergenceController 

4 

5 

class CheckConvergence(ConvergenceController):
    """
    Perform simple checks on convergence for SDC iterations.

    Iteration is terminated on the finest level when any of the following criteria is met:
     - Residual tolerance (`restol`)
     - Maximum number of iterations (`maxiter`)
     - Embedded error estimate below `e_tol` (only if `e_tol` is set and an estimate is available)
     - A forced stop via `S.status.force_done`

    Setting `S.status.force_continue` vetoes convergence for one check; it is reset after every
    call to `check_iteration_status`.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Define default parameters here

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        # The embedded error estimate is only needed when an `e_tol` was requested for the levels.
        level_params = description.get('level_params', {})
        defaults = {'control_order': +200, 'use_e_tol': 'e_tol' in level_params}

        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    def dependencies(self, controller, description, **kwargs):
        """
        Prepare MPI logical operations if `useMPI` is set and load the embedded error estimator if
        `use_e_tol` is set.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        if self.params.useMPI:
            # provides self.MPI_LAND / self.MPI_LOR / self.MPI_BOOL used in `communicate_convergence`
            self.prepare_MPI_logical_operations()

        super().dependencies(controller, description)

        if self.params.use_e_tol:
            from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
                EstimateEmbeddedError,
            )

            controller.add_convergence_controller(
                EstimateEmbeddedError,
                description=description,
            )

        return None

    @staticmethod
    def check_convergence(S, self=None):
        """
        Check the convergence of a single step.
        Test the residual, the embedded error estimate and the max. number of iterations as well as
        allowing overrides to both stop and continue.

        Args:
            S (pySDC.Step): The current step
            self (CheckConvergence, optional): Instance used only to emit debug output; when omitted,
                nothing is logged

        Returns:
            bool: Convergence status of the step
        """
        # do all this on the finest level
        L = S.levels[0]

        # get residual and check against prescribed tolerance (plus check number of iterations)
        iter_converged = S.status.iter >= S.params.maxiter
        res_converged = L.status.residual <= L.params.restol

        # Compare against None explicitly: an embedded error estimate of exactly 0.0 is a valid value
        # and must not be mistaken for "no estimate available".
        e_tol = L.params.get('e_tol')
        error_estimate = L.status.get('error_embedded_estimate')
        e_tol_converged = e_tol is not None and error_estimate is not None and error_estimate < e_tol

        # `force_done` may be None, hence the coercion to a proper bool; `force_continue` vetoes everything.
        converged = bool(
            (iter_converged or res_converged or e_tol_converged or S.status.force_done)
            and not S.status.force_continue
        )

        # print information for debugging
        if converged and self is not None:
            self.debug(
                f'Declared convergence: maxiter reached[{"x" if iter_converged else " "}] restol reached[{"x" if res_converged else " "}] e_tol reached[{"x" if e_tol_converged else " "}]',
                S,
            )
        return converged

    def check_iteration_status(self, controller, S, **kwargs):
        """
        Routine to determine whether to stop iterating (testing the residual, the embedded error
        estimate and the max. number of iterations)

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step
            **kwargs: May contain `comm`, in which case the convergence status is communicated

        Returns:
            None
        """
        S.status.done = self.check_convergence(S, self)

        # A communicator signals that the convergence status has to be shared between processes.
        # Forward only `comm`, since `communicate_convergence` accepts no other keyword.
        if 'comm' in kwargs:
            self.communicate_convergence(controller, S, comm=kwargs['comm'])

        # `force_continue` acts for a single check only
        S.status.force_continue = False

        return None

    def communicate_convergence(self, controller, S, comm):
        """
        Communicate the convergence status during `check_iteration_status` if MPI is used.

        With `all_to_done`, all steps reduce their status collectively. Otherwise, each step
        receives the status of its predecessor and sends its own status forward.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step.step): The current step
            comm (mpi4py.MPI.Comm): MPI communicator

        Returns:
            None
        """
        # Either gather information about all status or send forward own
        if controller.params.all_to_done:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)
            # all steps must be done; a forced stop on any step forces all of them to stop
            S.status.done = comm.allreduce(sendobj=S.status.done, op=self.MPI_LAND)
            S.status.force_done = comm.allreduce(sendobj=S.status.force_done, op=self.MPI_LOR)
            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)

            S.status.done = S.status.done or S.status.force_done

        else:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)

            # check if an open request of the status send is pending
            controller.wait_with_interrupt(request=controller.req_status)

            # A forced stop skips the exchange with neighbouring steps, but the post_comm hooks below
            # still run so that the pre_comm call above is always matched.
            if not S.status.force_done:
                # recv status
                if not S.status.first and not S.status.prev_done:
                    buff = np.empty(1, dtype=bool)
                    self.Recv(comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
                    S.status.prev_done = buff[0]
                    S.status.done = S.status.done and S.status.prev_done

                # send status forward
                if not S.status.last:
                    buff = np.empty(1, dtype=bool)
                    buff[0] = S.status.done
                    self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)

        return None