Coverage for pySDC/projects/Resilience/Lorenz.py: 94%
71 statements
coverage.py v7.8.0, created at 2025-04-01 13:12 +0000
1# script to run a Lorenz attractor problem
2import numpy as np
3import matplotlib.pyplot as plt
5from pySDC.helpers.stats_helper import get_sorted
6from pySDC.implementations.problem_classes.Lorenz import LorenzAttractor
7from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
8from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
9from pySDC.core.errors import ConvergenceError
10from pySDC.projects.Resilience.hook import LogData, hook_collection
11from pySDC.projects.Resilience.strategies import merge_descriptions
12from pySDC.projects.Resilience.sweepers import generic_implicit_efficient
def run_Lorenz(
    custom_description=None,
    num_procs=1,
    Tend=1.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a Lorenz attractor problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list/tuple of pySDC.Hook): Hook(s) to store data. They are appended
            to `hook_collection`.
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        use_MPI (bool): Whether or not to use MPI
        **kwargs: Only `comm` is read: the MPI communicator used when `use_MPI` is True
            (defaults to `MPI.COMM_WORLD`). Other keyword arguments are ignored.

    Returns:
        dict: The stats object
        controller: The controller
        bool: Whether the code crashed (a ConvergenceError or ZeroDivisionError was raised
            during the run, in which case the stats gathered so far are returned)
    """

    # initialize level parameters
    level_params = dict()
    level_params['dt'] = 1e-2

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['initial_guess'] = 'copy'

    problem_params = {
        'newton_tol': 1e-9,
        'newton_maxiter': 99,
    }

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 4

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30
    # accept a single hook or a list/tuple of hooks; always build a fresh list so neither
    # `hook_collection` nor the caller's sequence gets mutated later on
    extra_hooks = list(hook_class) if isinstance(hook_class, (list, tuple)) else [hook_class]
    controller_params['hook_class'] = hook_collection + extra_hooks
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_class'] = LorenzAttractor
    description['problem_params'] = problem_params
    description['sweeper_class'] = generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller
    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(controller_params=controller_params, description=description, comm=comm)
        # the MPI controller only holds the step owned by this rank
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(
            num_procs=num_procs, controller_params=controller_params, description=description
        )
        P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    # call main function to get things done...
    crash = False
    try:
        uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except (ConvergenceError, ZeroDivisionError) as e:
        # keep whatever was recorded before the failure and report the crash to the caller
        print(f'Warning: Premature termination: {e}')
        stats = controller.return_stats()
        crash = True

    return stats, controller, crash
def plot_solution(stats):  # pragma: no cover
    """
    Draw the recorded solution trajectory as a line in a 3D figure and show it.

    Args:
        stats (dict): The stats object of the run

    Returns:
        None
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    # each recorded entry is (time, state); only the states are plotted
    states = [entry[1] for entry in get_sorted(stats, type='u')]
    xs, ys, zs = ([state[axis] for state in states] for axis in range(3))
    ax.plot(xs, ys, zs)

    for set_label, label in ((ax.set_xlabel, 'x'), (ax.set_ylabel, 'y'), (ax.set_zlabel, 'z')):
        set_label(label)

    plt.show()
def check_solution(stats, controller, thresh=5e-4):
    """
    Check if the global error solution wrt. a scipy reference solution is tolerable.
    This is also a check for the global error hook.

    The error of the last recorded solution is computed against the problem's `u_exact` in the
    infinity norm and must (a) match the value logged as `e_global_post_run` exactly and
    (b) be below `thresh`.

    Args:
        stats (dict): The stats object of the run
        controller (pySDC.Controller.controller): The controller (non-MPI or MPI)
        thresh (float): Threshold for accepting the accuracy

    Returns:
        None

    Raises:
        AssertionError: If no solution was recorded, the errors disagree, or the error is too large
    """
    u = get_sorted(stats, type='u')
    if not u:
        raise AssertionError('No solution was recorded in the stats!')

    # controller_nonMPI keeps all steps in `MS`, whereas controller_MPI only holds its own step in `S`
    step = controller.MS[0] if hasattr(controller, 'MS') else controller.S
    u_exact = step.levels[0].prob.u_exact(t=u[-1][0])
    error = np.linalg.norm(u[-1][1] - u_exact, np.inf)
    error_hook = get_sorted(stats, type='e_global_post_run')[-1][1]

    # raise explicitly rather than using `assert`, so the checks survive `python -O`;
    # the comparisons are written so that NaN errors still fail, as with the original asserts
    if error != error_hook:
        raise AssertionError(f'Expected errors to match, got {error:.2e} and {error_hook:.2e}!')
    if not error < thresh:
        raise AssertionError(f"Error too large, got e={error:.2e}")
def main(plotting=True):
    """
    Integrate the Lorenz problem with adaptivity up to t=10 and verify the final accuracy.

    Args:
        plotting (bool): Whether to show a 3D plot of the trajectory afterwards

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    # adaptive step size selection with a fixed error tolerance
    description_overrides = {
        'convergence_controllers': {Adaptivity: {'e_tol': 1e-5}},
    }
    controller_overrides = {'logger_level': 30}
    hooks = [LogData, LogGlobalErrorPostRun]

    stats, controller, _crashed = run_Lorenz(
        custom_description=description_overrides,
        custom_controller_params=controller_overrides,
        Tend=10.0,
        hook_class=hooks,
    )

    check_solution(stats, controller, 5e-4)

    if plotting:  # pragma: no cover
        plot_solution(stats)
if __name__ == "__main__":
    # executed as a script: run the accuracy check and show the trajectory plot
    main(plotting=True)