Coverage for pySDC/projects/Resilience/Lorenz.py: 94%

71 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 13:12 +0000

1# script to run a Lorenz attractor problem 

2import numpy as np 

3import matplotlib.pyplot as plt 

4 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.implementations.problem_classes.Lorenz import LorenzAttractor 

7from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

8from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity 

9from pySDC.core.errors import ConvergenceError 

10from pySDC.projects.Resilience.hook import LogData, hook_collection 

11from pySDC.projects.Resilience.strategies import merge_descriptions 

12from pySDC.projects.Resilience.sweepers import generic_implicit_efficient 

13 

14 

def run_Lorenz(
    custom_description=None,
    num_procs=1,
    Tend=1.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a Lorenz attractor problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets
        num_procs (int): Number of steps for MSSDC (only passed to the non-MPI controller)
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): A hook, or a list of hooks, to store data.
            They are appended to the hooks in `hook_collection`.
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        use_MPI (bool): Whether or not to use MPI
        **kwargs: Only `comm` is read, and only if `use_MPI` is set. It is the communicator
            handed to the MPI controller and defaults to `MPI.COMM_WORLD`.

    Returns:
        dict: The stats object
        controller: The controller
        bool: Whether the code crashed
    """

    # initialize level parameters
    level_params = dict()
    level_params['dt'] = 1e-2

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['initial_guess'] = 'copy'

    problem_params = {
        'newton_tol': 1e-9,
        'newton_maxiter': 99,
    }

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 4

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30
    # a single hook is wrapped in a list so that it can be appended to the default hooks;
    # isinstance is used rather than an exact type comparison so list subclasses work as well
    controller_params['hook_class'] = hook_collection + (hook_class if isinstance(hook_class, list) else [hook_class])
    controller_params['mssdc_jac'] = False

    # entries in the custom parameters take precedence over the presets above
    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_class'] = LorenzAttractor
    description['problem_params'] = problem_params
    description['sweeper_class'] = generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller
    if use_MPI:
        # imported here so that mpi4py is only required when MPI is actually requested
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(controller_params=controller_params, description=description, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(
            num_procs=num_procs, controller_params=controller_params, description=description
        )
        P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    # call main function to get things done...
    crash = False
    try:
        # the final solution is not needed here, only the stats are returned
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except (ConvergenceError, ZeroDivisionError) as e:
        # report the crash to the caller, together with the stats recorded so far
        print(f'Warning: Premature termination: {e}')
        stats = controller.return_stats()
        crash = True

    return stats, controller, crash

118 

119 

def plot_solution(stats):  # pragma: no cover
    """
    Show the recorded solution as a line in a 3D plot.

    Args:
        stats (dict): The stats object of the run

    Returns:
        None
    """
    ax = plt.figure().add_subplot(projection='3d')

    # every entry holds the solution at index 1; its first three components are plotted
    states = [entry[1] for entry in get_sorted(stats, type='u')]
    x_values = [state[0] for state in states]
    y_values = [state[1] for state in states]
    z_values = [state[2] for state in states]
    ax.plot(x_values, y_values, z_values)

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()

139 

140 

def check_solution(stats, controller, thresh=5e-4):
    """
    Verify that the global error at the end of the run is below a threshold.

    The error is the maximum norm of the difference between the last recorded solution and
    the reference from the problem's `u_exact` at the same time. It must also equal the last
    value recorded under `e_global_post_run`, which makes this a check of the global error hook.

    Args:
        stats (dict): The stats object of the run
        controller (pySDC.Controller.controller): The controller
        thresh (float): Threshold for accepting the accuracy

    Returns:
        None

    Raises:
        AssertionError: If the two errors differ or the error is not below the threshold
    """
    solution = get_sorted(stats, type='u')
    problem = controller.MS[0].levels[0].prob

    # entries are indexed as (time, solution); only the last one is compared
    final_entry = solution[-1]
    reference = problem.u_exact(t=final_entry[0])
    error = np.linalg.norm(final_entry[1] - reference, ord=np.inf)

    recorded_errors = get_sorted(stats, type='e_global_post_run')
    error_hook = recorded_errors[-1][1]

    assert error == error_hook, f'Expected errors to match, got {error:.2e} and {error_hook:.2e}!'
    assert error < thresh, f"Error too large, got e={error:.2e}"

161 

162 

def main(plotting=True):
    """
    Run the Lorenz problem with adaptivity up to Tend=10 and check the accuracy.

    Args:
        plotting (bool): Plot the solution or not

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    run_args = {
        'custom_description': {'convergence_controllers': {Adaptivity: {'e_tol': 1e-5}}},
        'custom_controller_params': {'logger_level': 30},
        'Tend': 10.0,
        # the post-run error hook records the value that check_solution compares against
        'hook_class': [LogData, LogGlobalErrorPostRun],
    }
    stats, controller, _ = run_Lorenz(**run_args)

    check_solution(stats, controller, thresh=5e-4)
    if plotting:  # pragma: no cover
        plot_solution(stats)

187 

188 

# when executed as a script, run the test problem; main() plots the solution by default
if __name__ == "__main__":
    main()