Coverage for pySDC/projects/Resilience/Lorenz.py: 94%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1# script to run a Lorenz attractor problem 

2import numpy as np 

3import matplotlib.pyplot as plt 

4 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.implementations.problem_classes.Lorenz import LorenzAttractor 

7from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

8from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity 

9from pySDC.core.errors import ConvergenceError 

10from pySDC.projects.Resilience.hook import LogData, hook_collection 

11from pySDC.projects.Resilience.strategies import merge_descriptions 

12from pySDC.projects.Resilience.sweepers import generic_implicit_efficient 

13 

14 

def run_Lorenz(
    custom_description=None,
    num_procs=1,
    Tend=1.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a Lorenz attractor problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list/tuple of pySDC.Hook): Hook(s) to store data, appended to `hook_collection`
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        use_MPI (bool): Whether or not to use MPI
        **kwargs: Only `comm` is read: the MPI communicator used when `use_MPI` is set
            (defaults to `MPI.COMM_WORLD`)

    Returns:
        dict: The stats object
        controller: The controller
        bool: Whether the code crashed
    """

    # initialize level parameters
    level_params = dict()
    level_params['dt'] = 1e-2

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'

    problem_params = {
        'newton_tol': 1e-9,
        'newton_maxiter': 99,
    }

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 4

    # Accept a single hook as well as any list/tuple of hooks. A plain
    # `type(...) == list` check would wrap a tuple as one bogus "hook".
    if isinstance(hook_class, (list, tuple)):
        extra_hooks = list(hook_class)
    else:
        extra_hooks = [hook_class]

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30
    controller_params['hook_class'] = hook_collection + extra_hooks
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_class'] = LorenzAttractor
    description['problem_params'] = problem_params
    description['sweeper_class'] = generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller; the MPI controller exposes its step as `S`, the
    # non-MPI one keeps all steps in the list `MS`
    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(controller_params=controller_params, description=description, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(num_procs=num_procs, controller_params=controller_params, description=description)
        P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    # call main function to get things done...
    crash = False
    try:
        uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except (ConvergenceError, ZeroDivisionError) as e:
        # keep whatever stats were gathered before the failure and report the crash
        print(f'Warning: Premature termination: {e}')
        stats = controller.return_stats()
        crash = True

    return stats, controller, crash

117 

118 

def plot_solution(stats):  # pragma: no cover
    """
    Show the trajectory recorded under the 'u' key of the stats as a 3D line plot.

    Args:
        stats (dict): The stats object of the run

    Returns:
        None
    """
    figure = plt.figure()
    axis = figure.add_subplot(projection='3d')

    solution = get_sorted(stats, type='u')
    # one coordinate list per component; each entry is (time, value) and the
    # first three components of the value are plotted
    xs, ys, zs = ([entry[1][component] for entry in solution] for component in range(3))
    axis.plot(xs, ys, zs)

    axis.set_xlabel('x')
    axis.set_ylabel('y')
    axis.set_zlabel('z')
    plt.show()

138 

139 

def check_solution(stats, controller, thresh=5e-4):
    """
    Check if the global error solution wrt. a scipy reference solution is tolerable.
    This is also a check for the global error hook.

    Works with both `controller_nonMPI` (steps in `controller.MS`) and
    `controller_MPI` (step in `controller.S`).

    Args:
        stats (dict): The stats object of the run
        controller (pySDC.Controller.controller): The controller
        thresh (float): Threshold for accepting the accuracy

    Returns:
        None

    Raises:
        AssertionError: If the manually computed error and the hook's error differ,
            or if the error is not below `thresh`
    """
    u = get_sorted(stats, type='u')

    # The non-MPI controller stores all steps in `MS`, the MPI controller only has `S`.
    # NOTE(review): with MPI, `stats` are the local rank's; the final step must be in them.
    step = controller.MS[0] if hasattr(controller, 'MS') else controller.S
    prob = step.levels[0].prob

    t_last = u[-1][0]
    u_exact = prob.u_exact(t=t_last)
    error = np.linalg.norm(u[-1][1] - u_exact, np.inf)
    error_hook = get_sorted(stats, type='e_global_post_run')[-1][1]

    # exact equality on purpose: this doubles as a check that the hook computes the same quantity
    assert error == error_hook, f'Expected errors to match, got {error:.2e} and {error_hook:.2e}!'
    assert error < thresh, f"Error too large, got e={error:.2e}"

160 

161 

def main(plotting=True):
    """
    Integrate the Lorenz system to t=10 with adaptive step sizes and verify the accuracy.

    Args:
        plotting (bool): Whether to display a 3D plot of the trajectory afterwards

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    adaptivity_description = {'convergence_controllers': {Adaptivity: {'e_tol': 1e-5}}}
    stats, controller, _crashed = run_Lorenz(
        custom_description=adaptivity_description,
        custom_controller_params={'logger_level': 30},
        Tend=10.0,
        hook_class=[LogData, LogGlobalErrorPostRun],
    )
    check_solution(stats, controller, thresh=5e-4)

    if plotting:  # pragma: no cover
        plot_solution(stats)

186 

187 

if __name__ == "__main__":
    # run the demo (including the plot) only when executed as a script
    main(plotting=True)