Coverage for pySDC/projects/Resilience/piline.py: 87%
141 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
1import numpy as np
2import matplotlib.pyplot as plt
4from pySDC.helpers.stats_helper import get_sorted
5from pySDC.implementations.problem_classes.Piline import piline
6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
7from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
8from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
9from pySDC.implementations.convergence_controller_classes.hotrod import HotRod
10from pySDC.projects.Resilience.hook import LogData, hook_collection
11from pySDC.projects.Resilience.strategies import merge_descriptions
def run_piline(
    custom_description=None,
    num_procs=1,
    Tend=20.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
):
    """
    Run a Piline problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): A hook, or a list of hooks, to store data. These are appended to
            the hooks in `hook_collection`.
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets

    Returns:
        dict: The stats object
        controller: The controller
        Tend: The time that was supposed to be integrated to
    """

    # initialize level parameters
    level_params = dict()
    level_params['dt'] = 5e-2

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['QE'] = 'PIC'

    problem_params = {
        'Vs': 100.0,
        'Rs': 1.0,
        'C1': 1.0,
        'Rpi': 0.2,
        'C2': 1.0,
        'Lpi': 1.0,
        'Rl': 5.0,
    }

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 4

    # accept a single hook as well as a list of hooks
    hooks = hook_class if isinstance(hook_class, list) else [hook_class]

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30
    controller_params['hook_class'] = hook_collection + hooks
    controller_params['mssdc_jac'] = False

    # user-supplied controller parameters take precedence over the defaults above
    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_class'] = piline  # pass problem class
    description['problem_params'] = problem_params  # pass problem parameters
    description['sweeper_class'] = imex_1st_order  # pass sweeper
    description['sweeper_params'] = sweeper_params  # pass sweeper parameters
    description['level_params'] = level_params  # pass level parameters
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller
    controller = controller_nonMPI(num_procs=num_procs, controller_params=controller_params, description=description)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        rnd_args = {'iteration': 4}
        args = {'time': 2.5, 'target': 0}
        prepare_controller_for_faults(controller, fault_stuff, rnd_args, args)

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # call main function to get things done; the final solution itself is not needed by callers
    _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    return stats, controller, Tend
def get_data(stats, recomputed=False):
    """
    Extract useful data from the stats.

    Args:
        stats (pySDC.stats): The stats object of the run
        recomputed (bool): Whether to exclude values that don't contribute to the final solution or not

    Returns:
        dict: Data. Array-valued entries are empty arrays if the corresponding stats were not recorded.
    """

    def _column(entries, col):
        """Return column `col` of the (time, value) pairs in `entries`, or an empty array if there are none."""
        if len(entries) == 0:
            # `np.array([])[:, col]` would raise an IndexError on the 1D empty array
            return np.array([])
        return np.array(entries)[:, col]

    # query every stats type only once instead of once per extracted quantity
    u = get_sorted(stats, type='u', recomputed=recomputed)
    dt = get_sorted(stats, type='dt', recomputed=recomputed)
    e_em = get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed)
    e_ex = get_sorted(stats, type='error_extrapolation_estimate', recomputed=recomputed)
    # restarts and sweeps are queried with recomputed=None, independently of the `recomputed` argument
    restarts = get_sorted(stats, type='restart', recomputed=None)
    sweeps = get_sorted(stats, type='sweeps', recomputed=None)

    data = {
        'v1': np.array([me[1][0] for me in u]),
        'v2': np.array([me[1][1] for me in u]),
        'p3': np.array([me[1][2] for me in u]),
        't': np.array([me[0] for me in u]),
        'dt': np.array([me[1] for me in dt]),
        't_dt': np.array([me[0] for me in dt]),
        'e_em': _column(e_em, 1),
        'e_ex': _column(e_ex, 1),
        'restarts': _column(restarts, 1),
        't_restarts': _column(restarts, 0),
        'sweeps': _column(sweeps, 1),
    }
    # mask of entries where both error estimates are available (not None)
    data['ready'] = np.logical_and(data['e_ex'] != np.array(None), data['e_em'] != np.array(None))
    data['restart_times'] = data['t_restarts'][data['restarts'] > 0]
    return data
def plot_error(data, ax, use_adaptivity=True, plot_restarts=False):
    """
    Plot the step size together with the embedded and extrapolated error estimates.

    Args:
        data (dict): Data prepared from stats by `get_data`
        ax: Somewhere to plot the step size; the error estimates are drawn on a twin y-axis of it
        use_adaptivity (bool): Whether adaptivity was used (only affects the legend position)
        plot_restarts (bool): Whether to plot vertical lines for restarts

    Returns:
        None
    """
    setup_mpl_from_accuracy_check()
    ax.plot(data['t_dt'], data['dt'], color='black')

    ready = data['ready']
    e_ax = ax.twinx()
    e_ax.plot(data['t'], data['e_em'], label=r'$\epsilon_\mathrm{embedded}$')
    e_ax.plot(data['t'][ready], data['e_ex'][ready], label=r'$\epsilon_\mathrm{extrapolated}$', ls='--')
    e_ax.plot(
        data['t'][ready],
        abs(data['e_em'][ready] - data['e_ex'][ready]),
        label='difference',
        ls='-.',
    )

    if plot_restarts:
        for t_restart in data['restart_times']:
            ax.axvline(t_restart, ls='-.', color='black', alpha=0.5)

    # line without data so that the (black) step size curve on `ax` also gets an entry in the legend of `e_ax`
    e_ax.plot([None, None], label=r'$\Delta t$', color='black')
    e_ax.set_yscale('log')
    e_ax.legend(frameon=False, loc='upper left' if use_adaptivity else 'upper right')

    # hard coded axis limits; presumably chosen so that plots of different runs share the same scale
    e_ax.set_ylim((7.367539795147197e-12, 1.109667868425781e-05))
    ax.set_ylim((0.012574322653781072, 0.10050387672423527))

    ax.set_xlabel('Time')
    ax.set_ylabel(r'$\Delta t$')
def setup_mpl_from_accuracy_check():
    """
    Apply the LaTeX-style matplotlib settings provided by `setup_mpl` in the Resilience accuracy check.

    The accuracy check module is only imported when this function is called.
    """
    from pySDC.projects.Resilience import accuracy_check

    accuracy_check.setup_mpl()
def plot_solution(data, ax):
    """
    Plot the solution components v1, v2 and p3 over time.

    Args:
        data (dict): Data prepared from stats by `get_data`
        ax: Somewhere to plot

    Returns:
        None
    """
    setup_mpl_from_accuracy_check()

    # one line per solution component, each with its own line style
    components = (('v1', '-'), ('v2', '--'), ('p3', '-.'))
    for name, line_style in components:
        ax.plot(data['t'], data[name], label=name, ls=line_style)

    ax.legend(frameon=False)
    ax.set_xlabel('Time')
def check_solution(data, use_adaptivity, num_procs, generate_reference=False):
    """
    Check the solution against a hard coded reference.

    Args:
        data (dict): Data prepared from stats by `get_data`
        use_adaptivity (bool): Whether adaptivity was used
        num_procs (int): Number of steps for MSSDC
        generate_reference (bool): Print a new reference to the console before comparing. If there is no
            reference for this configuration yet, only print and skip the comparison.

    Returns:
        None

    Raises:
        ValueError: If no reference exists for this combination of `use_adaptivity` and `num_procs` and
            `generate_reference` is False
        AssertionError: If a quantity deviates from its reference
    """
    # references keyed by (use_adaptivity, num_procs)
    references = {
        (True, 1): (
            'Error when using adaptivity in serial:',
            {
                'v1': 83.88330442715265,
                'v2': 80.62692930055763,
                'p3': 16.13594155613822,
                'e_em': 4.922608098922865e-09,
                'e_ex': 4.4120077421613226e-08,
                'dt': 0.05,
                'restarts': 1.0,
                'sweeps': 2416.0,
                't': 20.03656747407325,
            },
        ),
        (True, 4): (
            'Error when using adaptivity in parallel:',
            {
                'v1': 83.88320903115796,
                'v2': 80.6269822629629,
                'p3': 16.136084724243805,
                'e_em': 4.0668446388281154e-09,
                'e_ex': 4.901094641240463e-09,
                'dt': 0.05,
                'restarts': 48.0,
                'sweeps': 2592.0,
                't': 20.041499821475185,
            },
        ),
        (False, 4): (
            'Error with fixed step size in parallel:',
            {
                'v1': 83.88400128006428,
                'v2': 80.62656202423844,
                'p3': 16.134849781053525,
                'e_em': 4.277040943634347e-09,
                'e_ex': 4.9707053288253756e-09,
                'dt': 0.05,
                'restarts': 0.0,
                'sweeps': 1600.0,
                't': 20.00000000000015,
            },
        ),
        (False, 1): (
            'Error with fixed step size in serial:',
            {
                'v1': 83.88400149770143,
                'v2': 80.62656173487008,
                'p3': 16.134849851184736,
                'e_em': 4.977994905175365e-09,
                'e_ex': 5.048084913047097e-09,
                'dt': 0.05,
                'restarts': 0.0,
                'sweeps': 1600.0,
                't': 20.00000000000015,
            },
        ),
    }

    e_ex = data['e_ex']
    got = {
        'v1': data['v1'][-1],
        'v2': data['v2'][-1],
        'p3': data['p3'][-1],
        'e_em': data['e_em'][-1],
        # the extrapolated estimate is None where it is unavailable, so take the last available value
        'e_ex': e_ex[e_ex != [None]][-1],
        'dt': data['dt'][-1],
        'restarts': data['restarts'].sum(),
        'sweeps': data['sweeps'].sum(),
        't': data['t'][-1],
    }

    if generate_reference:
        print(f'Adaptivity: {use_adaptivity}, num_procs={num_procs}')
        print('expected = {')
        for k, v in got.items():
            if isinstance(v, (list, np.ndarray)):
                print(f'    \'{k}\': {v[v!=[None]][-1]},')
            else:
                print(f'    \'{k}\': {v},')
        print('}')

    key = (bool(use_adaptivity), num_procs)
    if key not in references:
        if generate_reference:
            # nothing to compare against yet; the freshly printed values serve as the new reference
            return
        raise ValueError(
            f'No reference solution for use_adaptivity={use_adaptivity} and num_procs={num_procs}. '
            'Available configurations (use_adaptivity, num_procs): ' + ', '.join(str(k) for k in references)
        )
    error_msg, expected = references[key]

    # NOTE(review): np.isclose uses a default absolute tolerance of 1e-8, which dominates rtol for the
    # ~1e-9 error estimates, so `e_em`/`e_ex` are effectively only checked to be below ~1e-8. Consider
    # passing atol=0 once the references are confirmed to be reproducible to rtol=1e-4.
    for k in expected.keys():
        assert np.isclose(
            expected[k], got[k], rtol=1e-4
        ), f'{error_msg} Expected {k}={expected[k]:.4e}, got {k}={got[k]:.4e}'
def residual_adaptivity(plot=False):
    """
    Run Piline in serial with adaptivity based on the residual and check the outcome.

    The run must keep the post-step residual below the tolerance and must actually vary the step size.

    Args:
        plot (bool): Whether to show the residual and the step size over time in a figure

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityResidual

    max_res = 1e-8
    adaptivity_params = {
        'e_tol': max_res,
        'e_tol_low': max_res / 10,
    }
    custom_description = {'convergence_controllers': {AdaptivityResidual: adaptivity_params}}
    stats, _, _ = run_piline(custom_description, num_procs=1)

    residual = get_sorted(stats, type='residual_post_step', recomputed=False)
    dt = get_sorted(stats, type='dt', recomputed=False)

    residual_times = [entry[0] for entry in residual]
    residual_values = [entry[1] for entry in residual]
    dt_times = [entry[0] for entry in dt]
    dt_values = [entry[1] for entry in dt]

    if plot:
        fig, ax = plt.subplots()
        dt_ax = ax.twinx()

        ax.plot(residual_times, residual_values)
        dt_ax.plot(dt_times, dt_values, color='black')
        plt.show()

    max_residual = max(residual_values)
    assert max_residual < max_res, f'Max. allowed residual is {max_res:.2e}, but got {max_residual:.2e}!'

    dt_std = np.std(dt_values)
    assert dt_std != 0, f'Expected the step size to change, but standard deviation is {dt_std:.2e}!'
def main():
    """
    Make a variety of tests to see if Hot Rod and Adaptivity work in serial as well as MSSDC.

    Every combination of fixed/adaptive step size and 1/4 processes is run with Hot Rod. The error
    estimates are plotted to `data/`, and the final values are compared to hard coded references.
    """
    generate_reference = False

    for use_adaptivity in (True, False):
        custom_description = {'convergence_controllers': {}}
        if use_adaptivity:
            custom_description['convergence_controllers'][Adaptivity] = {
                'e_tol': 1e-7,
                'embedded_error_flavor': 'linearized',
            }

        for num_procs in (1, 4):
            # Hot Rod only keeps storage in serial runs
            custom_description['convergence_controllers'][HotRod] = {'HotRod_tol': 1, 'no_storage': num_procs > 1}
            stats, _, _ = run_piline(custom_description, num_procs=num_procs)
            data = get_data(stats, recomputed=False)

            fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))
            plot_error(data, ax, use_adaptivity)
            prefix = 'piline_hotrod_adaptive' if use_adaptivity else 'piline_hotrod'
            fig.savefig(f'data/{prefix}_{num_procs}procs.png', bbox_inches='tight', dpi=300)

            # the solution itself is only plotted for the adaptive parallel run
            if use_adaptivity and num_procs == 4:
                sol_fig, sol_ax = plt.subplots(1, 1, figsize=(3.5, 3))
                plot_solution(data, sol_ax)
                sol_fig.savefig('data/piline_solution_adaptive.png', bbox_inches='tight', dpi=300)
                plt.close(sol_fig)

            check_solution(data, use_adaptivity, num_procs, generate_reference)
            plt.close(fig)
if __name__ == "__main__":
    # residual-based adaptivity check first, then the Hot Rod / adaptivity regression tests
    residual_adaptivity()
    main()