Coverage for pySDC/projects/Resilience/Lorenz.py: 94%
70 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1# script to run a Lorenz attractor problem
2import numpy as np
3import matplotlib.pyplot as plt
5from pySDC.helpers.stats_helper import get_sorted
6from pySDC.implementations.problem_classes.Lorenz import LorenzAttractor
7from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
8from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
9from pySDC.core.errors import ConvergenceError
10from pySDC.projects.Resilience.hook import LogData, hook_collection
11from pySDC.projects.Resilience.strategies import merge_descriptions
12from pySDC.projects.Resilience.sweepers import generic_implicit_efficient
def run_Lorenz(
    custom_description=None,
    num_procs=1,
    Tend=1.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a Lorenz attractor problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets
        num_procs (int): Number of steps for MSSDC (only used by the non-MPI controller)
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list/tuple of pySDC.Hook): Hook(s) to store data. They are appended to
            the hooks in `hook_collection`.
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        use_MPI (bool): Whether or not to use MPI
        **kwargs: Only `comm` is read, and only when `use_MPI` is True. It is the MPI communicator
            handed to the controller and defaults to `MPI.COMM_WORLD`. Other keyword arguments are ignored.

    Returns:
        dict: The stats object
        controller: The controller
        bool: Whether the code crashed
    """

    # initialize level parameters
    level_params = dict()
    level_params['dt'] = 1e-2

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'

    problem_params = {
        'newton_tol': 1e-9,
        'newton_maxiter': 99,
    }

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 4

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30
    # accept either a single hook class or a collection of them
    extra_hooks = list(hook_class) if isinstance(hook_class, (list, tuple)) else [hook_class]
    controller_params['hook_class'] = hook_collection + extra_hooks
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_class'] = LorenzAttractor
    description['problem_params'] = problem_params
    description['sweeper_class'] = generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller
    if use_MPI:
        # imported lazily so that mpi4py is only required for MPI runs
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(controller_params=controller_params, description=description, comm=comm)
        # the MPI controller holds a single step `S`
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(
            num_procs=num_procs, controller_params=controller_params, description=description
        )
        # the non-MPI controller holds a list of steps `MS`
        P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    # call main function to get things done...
    crash = False
    try:
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except (ConvergenceError, ZeroDivisionError) as e:
        # keep whatever was recorded before the failure and flag the crash to the caller
        print(f'Warning: Premature termination: {e}')
        stats = controller.return_stats()
        crash = True

    return stats, controller, crash
def plot_solution(stats):  # pragma: no cover
    """
    Show a 3D line plot of the logged solution trajectory.

    Args:
        stats (dict): The stats object of the run

    Returns:
        None
    """
    # the sorted entries are (time, value) pairs; only the values are plotted
    states = [entry[1] for entry in get_sorted(stats, type='u')]

    figure = plt.figure()
    axis = figure.add_subplot(projection='3d')
    axis.plot(
        [state[0] for state in states],
        [state[1] for state in states],
        [state[2] for state in states],
    )
    for set_label, label in ((axis.set_xlabel, 'x'), (axis.set_ylabel, 'y'), (axis.set_zlabel, 'z')):
        set_label(label)
    plt.show()
def check_solution(stats, controller, thresh=5e-4):
    """
    Check whether the global error with respect to the problem's `u_exact` reference solution is tolerable.
    This is also a check for the global error hook.

    Args:
        stats (dict): The stats object of the run
        controller (pySDC.Controller.controller): The controller, either the non-MPI or the MPI one
        thresh (float): Threshold for accepting the accuracy

    Returns:
        None

    Raises:
        AssertionError: If the solution or the global error was not logged, if the logged global error
            differs from the recomputed one, or if the error is not below `thresh`
    """
    u = get_sorted(stats, type='u')
    e_global = get_sorted(stats, type='e_global_post_run')
    assert len(u) > 0, "No solution ('u') was logged!"
    assert len(e_global) > 0, "No global error ('e_global_post_run') was logged! Was LogGlobalErrorPostRun used?"

    # the non-MPI controller keeps its steps in `MS`, the MPI controller its single step in `S`
    step = controller.MS[0] if hasattr(controller, 'MS') else controller.S
    t_last, u_last = u[-1]
    u_exact = step.levels[0].prob.u_exact(t=t_last)
    error = np.linalg.norm(u_last - u_exact, np.inf)
    error_hook = e_global[-1][1]

    # exact equality on purpose: the hook is expected to compute the very same quantity
    assert error == error_hook, f'Expected errors to match, got {error:.2e} and {error_hook:.2e}!'
    assert error < thresh, f"Error too large, got e={error:.2e}"
def main(plotting=True):
    """
    Integrate the Lorenz problem with adaptivity up to t=10 and verify the accuracy of the result.

    Args:
        plotting (bool): Whether to show a 3D plot of the trajectory afterwards

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    stats, controller, _ = run_Lorenz(
        custom_description={'convergence_controllers': {Adaptivity: {'e_tol': 1e-5}}},
        custom_controller_params={'logger_level': 30},
        Tend=10.0,
        # the global error hook provides the 'e_global_post_run' entries read by check_solution
        hook_class=[LogData, LogGlobalErrorPostRun],
    )
    check_solution(stats, controller, 5e-4)

    if plotting:  # pragma: no cover
        plot_solution(stats)
if __name__ == "__main__":
    # run the accuracy check, and show the plot, when executed as a script
    main(plotting=True)