Coverage for pySDC/implementations/convergence_controller_classes/check_convergence.py: 84%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1import numpy as np 

2 

3from pySDC.core.convergence_controller import ConvergenceController 

4 

5 

class CheckConvergence(ConvergenceController):
    """
    Perform simple checks on convergence for SDC iterations.

    Iteration is terminated via one of three criteria:
     - Residual tolerance (`restol`)
     - Increment tolerance (`e_tol`), if it is set in the level parameters
     - Maximum number of iterations

    In addition, the step status flags `force_done` and `force_continue` can override
    these criteria to stop or to continue iterating, respectively.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Define default parameters here

        Args:
            controller (pySDC.Controller): The controller
            params (dict): The params passed for this specific convergence controller
            description (dict): The description object used to instantiate the controller

        Returns:
            (dict): The updated params dictionary
        """
        # `use_e_tol` defaults to whether the user put an `e_tol` entry in the level parameters
        defaults = {'control_order': +200, 'use_e_tol': 'e_tol' in description['level_params']}

        # entries coming from the parent class' setup take precedence over the defaults above
        return {**defaults, **super().setup(controller, params, description, **kwargs)}

    def dependencies(self, controller, description, **kwargs):
        """
        Load the embedded error estimator if needed.

        If `useMPI` is set in the params, `prepare_MPI_logical_operations` is called first.

        Args:
            controller (pySDC.Controller): The controller
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        if self.params.useMPI:
            self.prepare_MPI_logical_operations()

        super().dependencies(controller, description)

        if self.params.use_e_tol:
            # imported locally, only when the increment-based criterion is actually requested
            from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
                EstimateEmbeddedError,
            )

            controller.add_convergence_controller(
                EstimateEmbeddedError,
                description=description,
            )

        return None

    @staticmethod
    def check_convergence(S, self=None):
        """
        Check the convergence of a single step.
        Test the residual and max. number of iterations as well as allowing overrides to both stop and continue.

        Args:
            S (pySDC.Step): The current step
            self (CheckConvergence): Optional instance of this class. If given, a debug message listing
                the satisfied criteria is emitted via `self.debug` when convergence is declared.

        Returns:
            bool: Convergence status of the step
        """
        # do all this on the finest level
        L = S.levels[0]

        # get residual and check against prescribed tolerance (plus check number of iterations)
        iter_converged = S.status.iter >= S.params.maxiter
        res_converged = L.status.residual <= L.params.restol

        # Compare against `None` explicitly rather than relying on truthiness: an increment of
        # exactly 0.0 is a valid value and must not be mistaken for "no increment available".
        e_tol_available = L.params.get('e_tol') is not None and L.status.get('increment') is not None
        e_tol_converged = L.status.increment < L.params.e_tol if e_tol_available else False

        converged = (
            iter_converged or res_converged or e_tol_converged or S.status.force_done
        ) and not S.status.force_continue

        # the `or` chain above yields its last operand, which is `None` if `force_done` was never set
        if converged is None:
            converged = False

        # print information for debugging
        if converged and self:
            self.debug(
                f'Declared convergence: maxiter reached[{"x" if iter_converged else " "}] restol reached[{"x" if res_converged else " "}] e_tol reached[{"x" if e_tol_converged else " "}]',
                S,
            )
        return converged

    def check_iteration_status(self, controller, S, **kwargs):
        """
        Routine to determine whether to stop iterating (currently testing the residual + the max. number of iterations)

        Args:
            controller (pySDC.Controller.controller): The controller
            S (pySDC.Step.step): The current step

        Returns:
            None
        """
        S.status.done = self.check_convergence(S, self)

        # NOTE(review): all kwargs are forwarded, but `communicate_convergence` only accepts `comm`;
        # any additional keyword would raise a TypeError - confirm callers only ever pass `comm`.
        if "comm" in kwargs:
            self.communicate_convergence(controller, S, **kwargs)

        # the override to keep iterating is consumed by this check and reset afterwards
        S.status.force_continue = False

        return None

    def communicate_convergence(self, controller, S, comm):
        """
        Communicate the convergence status during `check_iteration_status` if MPI is used.

        With `controller.params.all_to_done`, the `done` and `force_done` flags are combined across
        the communicator using `allreduce`. Otherwise, the status is received from the previous
        slot and sent on to the next slot.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step.step): The current step
            comm (mpi4py.MPI.Comm): MPI communicator

        Returns:
            None
        """
        # Either gather information about all status or send forward own
        if controller.params.all_to_done:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)
            S.status.done = comm.allreduce(sendobj=S.status.done, op=self.MPI_LAND)
            S.status.force_done = comm.allreduce(sendobj=S.status.force_done, op=self.MPI_LOR)
            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)

            S.status.done = S.status.done or S.status.force_done

        else:
            for hook in controller.hooks:
                hook.pre_comm(step=S, level_number=0)

            # check if an open request of the status send is pending
            controller.wait_with_interrupt(request=controller.req_status)
            if S.status.force_done:
                # NOTE(review): this early exit skips the `post_comm` hooks although `pre_comm`
                # has already been called above - confirm that this imbalance is intended.
                return None

            # recv status
            if not S.status.first and not S.status.prev_done:
                buff = np.empty(1, dtype=bool)
                self.Recv(comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
                S.status.prev_done = buff[0]
                # this step only counts as done if the previous one is done as well
                S.status.done = S.status.done and S.status.prev_done

            # send status forward
            if not S.status.last:
                buff = np.empty(1, dtype=bool)
                buff[0] = S.status.done
                self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

            for hook in controller.hooks:
                hook.post_comm(step=S, level_number=0, add_to_stats=True)