Coverage for pySDC/projects/Resilience/paper_plots.py: 0%
26 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 16:55 +0000
1# script to make pretty plots for papers or talks
2import numpy as np
3import matplotlib as mpl
4import matplotlib.pyplot as plt
5from pySDC.projects.Resilience.fault_stats import (
6 FaultStats,
7 run_Lorenz,
8 run_Schroedinger,
9 run_vdp,
10 run_quench,
11 run_AC,
12 RECOVERY_THRESH_ABS,
13)
14from pySDC.projects.Resilience.strategies import (
15 BaseStrategy,
16 AdaptivityStrategy,
17 IterateStrategy,
18 HotRodStrategy,
19 DIRKStrategy,
20 ERKStrategy,
21 AdaptivityPolynomialError,
22)
23from pySDC.helpers.plot_helper import setup_mpl, figsize_by_journal
24from pySDC.helpers.stats_helper import get_sorted
# Factor to convert centimetres into inches (matplotlib figure sizes are in inches); uses 2.5 cm per inch.
cm = 1 / 2.5
# Text width of the document in inches (11.9446244611 cm), used for some figure widths.
TEXTWIDTH = 11.9446244611 * cm
# Layout passed to `figsize_by_journal`; reassigned by some of the `make_plots_for_*` functions below.
JOURNAL = 'Springer_Numerical_Algorithms'
# Directory into which `savefig` writes figures; reassigned by some of the `make_plots_for_*` functions below.
BASE_PATH = 'data/paper'
def get_stats(problem, path='data/stats-jusuf', num_procs=1, strategy_type='SDC'):
    """
    Create a FaultStats object for a given problem to use for the plots.
    Note that the statistics need to be already generated somewhere else, this function will only load them.

    Args:
        problem (function): A problem to run
        path (str): Path to the associated stats for the problem
        num_procs (int): Passed on to `FaultStats` as `num_procs`
        strategy_type (str): Which family of strategies to load, either 'SDC' or 'RK'

    Returns:
        FaultStats: Object to analyse resilience statistics from

    Raises:
        ValueError: If `strategy_type` is neither 'SDC' nor 'RK'
    """
    if strategy_type == 'SDC':
        strategies = [BaseStrategy(), AdaptivityStrategy(), IterateStrategy()]
        # for 'JSC_beamer' figures only the first three strategies are included
        if JOURNAL not in ['JSC_beamer']:
            strategies += [HotRodStrategy(), AdaptivityPolynomialError()]
    elif strategy_type == 'RK':
        strategies = [DIRKStrategy()]
        # the explicit RK strategy is only included for the Lorenz and van der Pol problems
        if problem.__name__ in ['run_Lorenz', 'run_vdp']:
            strategies += [ERKStrategy()]
    else:
        # fail early with a clear message instead of an UnboundLocalError on `strategies` below
        raise ValueError(f"Unknown strategy_type {strategy_type!r}; expected 'SDC' or 'RK'")

    stats_analyser = FaultStats(
        prob=problem,
        strategies=strategies,
        faults=[False, True],
        reload=True,
        recovery_thresh=1.1,
        # problem specific absolute recovery threshold, 0 if none is registered for this problem
        recovery_thresh_abs=RECOVERY_THRESH_ABS.get(problem, 0),
        mode='default',
        stats_path=path,
        num_procs=num_procs,
    )
    stats_analyser.get_recovered()
    return stats_analyser
def my_setup_mpl(**kwargs):
    """
    Reset matplotlib to the common settings used for all figures in this file.

    Args:
        **kwargs: Accepted for call compatibility, but ignored
    """
    setup_mpl(reset=True, font_size=8)
    mpl.rcParams['lines.markersize'] = 6
def savefig(fig, name, format='pdf', tight_layout=True):  # pragma: no cover
    """
    Store a figure as `{BASE_PATH}/{name}.{format}` and report where it went.

    Args:
        fig (Matplotlib.Figure): The figure to store
        name (str): File name without extension
        format (str): File extension / output format, e.g. 'pdf' or 'png'
        tight_layout (bool): Call `fig.tight_layout()` before saving if True, otherwise leave the layout as is

    Returns:
        None
    """
    if tight_layout:
        fig.tight_layout()

    out_path = f'{BASE_PATH}/{name}.{format}'
    fig.savefig(out_path, bbox_inches='tight', transparent=True, dpi=200)
    print(f'saved "{out_path}"')
def analyse_resilience(problem, path='data/stats', **kwargs):  # pragma: no cover
    """
    Load resilience statistics for a problem, print the faults the iterate strategy did not recover from
    (excluding faults in bit 1) and make the strategy comparison and recovery rate plots.

    Args:
        problem (function): A problem to run
        path (str): Path to the associated stats for the problem
        **kwargs: Passed on to the plotting functions and from there to `savefig`

    Returns:
        None
    """
    analyser = get_stats(problem, path)
    # NOTE(review): `get_stats` already calls `get_recovered`; kept here as in the original
    analyser.get_recovered()

    iterate = IterateStrategy()
    # faults the iterate strategy did not recover from ...
    unrecovered = analyser.get_mask(strategy=iterate, key='recovered', val=False)
    # ... restricted to those not in bit 1 (presumably flips there cause overflow, judging by the original name)
    relevant = analyser.get_mask(strategy=iterate, key='bit', val=1, op='uneq', old_mask=unrecovered)
    analyser.print_faults(relevant)

    compare_strategies(analyser, **kwargs)
    plot_recovery_rate(analyser, **kwargs)
def compare_strategies(stats_analyser, **kwargs):  # pragma: no cover
    """
    Draw the strategy comparison produced by `FaultStats.compare_strategies` and save it as 'compare_strategies'.

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        **kwargs: Passed on to `savefig`

    Returns:
        None
    """
    my_setup_mpl()
    figure, axis = plt.subplots(figsize=(TEXTWIDTH, 5 * cm))
    stats_analyser.compare_strategies(ax=axis)
    savefig(figure, 'compare_strategies', **kwargs)
def plot_recovery_rate(stats_analyser, **kwargs):  # pragma: no cover
    """
    Plot the recovery rate over the flipped bit for all faults (left panel) and for recoverable faults only
    (right panel), then save the figure as 'recovery_rate_compared'.

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        **kwargs: Passed on to `savefig`

    Returns:
        None
    """
    my_setup_mpl()
    fig, (ax_all, ax_recoverable) = plt.subplots(1, 2, figsize=(TEXTWIDTH, 5 * cm), sharex=True, sharey=True)

    stats_analyser.plot_things_per_things(
        'recovered',
        'bit',
        False,
        op=stats_analyser.rec_rate,
        args={'ylabel': 'recovery rate'},
        plotting_args={'markevery': 5},
        ax=ax_all,
    )
    plot_recovery_rate_recoverable_only(stats_analyser, fig, ax_recoverable, ylabel='')

    # a single legend is enough, the right panel keeps its own
    ax_all.get_legend().remove()
    ax_all.set_title('All faults')
    ax_recoverable.set_title('Only recoverable faults')
    ax_all.set_ylim((-0.05, 1.05))

    savefig(fig, 'recovery_rate_compared', **kwargs)
def plot_recovery_rate_recoverable_only(stats_analyser, fig, ax, **kwargs):  # pragma: no cover
    """
    Plot the recovery rate considering only faults that can be recovered theoretically.

    Each strategy is drawn separately, masked by the faults that are fixable for that particular strategy.

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        fig (matplotlib.pyplot.figure): Figure in which to plot
        ax (matplotlib.pyplot.axes): Somewhere to plot
        **kwargs: Forwarded (as a copy) via the `args` parameter of `plot_things_per_things`

    Returns:
        None
    """
    for strategy in stats_analyser.strategies:
        fixable_mask = stats_analyser.get_fixable_faults_only(strategy=strategy)

        stats_analyser.plot_things_per_things(
            'recovered',
            'bit',
            False,
            op=stats_analyser.rec_rate,
            mask=fixable_mask,
            args=dict(kwargs),
            ax=ax,
            fig=fig,
            strategies=[strategy],
            plotting_args={'markevery': 5},
        )
def compare_recovery_rate_problems(**kwargs):  # pragma: no cover
    """
    Compare the recovery rate for the van der Pol, Quench, Schroedinger and Allen-Cahn problems.
    Only faults that can be recovered are shown.

    Args:
        **kwargs: Passed on to `get_stats` (e.g. `num_procs`, `strategy_type`) and encoded in the file name

    Returns:
        None
    """
    problems = [run_vdp, run_quench, run_Schroedinger, run_AC]
    stats = [get_stats(problem, **kwargs) for problem in problems]
    titles = ['Van der Pol', 'Quench', r'Schr\"odinger', 'Allen-Cahn']

    my_setup_mpl()
    fig, axs = plt.subplots(2, 2, figsize=figsize_by_journal(JOURNAL, 1, 0.8), sharey=True)
    for stats_analyser, ax, title in zip(stats, axs.flatten(), titles):
        plot_recovery_rate_recoverable_only(stats_analyser, fig, ax, ylabel='', title=title)

    for ax in axs.flatten():
        ax.get_legend().remove()

    # show a single legend; which panel hosts it depends on the strategy type
    if kwargs.get('strategy_type', 'SDC') == 'SDC':
        axs[1, 1].legend(frameon=False, loc="lower right")
    else:
        axs[0, 1].legend(frameon=False, loc="lower right")
    axs[0, 0].set_ylim((-0.05, 1.05))
    axs[1, 0].set_ylabel('recovery rate')
    axs[0, 0].set_ylabel('recovery rate')

    name = ''.join(f'_{key}-{val}' for key, val in kwargs.items())

    # `savefig` appends the extension itself; passing '.pdf' here produced '*.pdf.pdf' files
    savefig(fig, f'compare_equations{name}')
def plot_adaptivity_stuff(run=False):  # pragma: no cover
    """
    Plot the solution for a van der Pol problem as well as the local error and cost associated with the base scheme and
    adaptivity in k and dt in order to demonstrate that adaptivity is useful.

    Args:
        run (bool): If True, run the simulations and store the relevant data in
                    './data/adaptivity_paper_plot_data_<strategy>.pickle'. Otherwise, load the data from these files,
                    which then need to exist already.

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep
    from pySDC.implementations.hooks.log_work import LogWork
    from pySDC.projects.Resilience.hook import LogData
    import pickle

    my_setup_mpl()
    scale = 0.5 if JOURNAL == 'JSC_beamer' else 1.0
    fig, axs = plt.subplots(3, 1, figsize=figsize_by_journal(JOURNAL, scale, 1), sharex=True, sharey=False)

    def plot_error(stats, ax, iter_ax, strategy, **kwargs):
        """
        Plot local error and cumulative sum of Newton iterations

        Args:
            stats (dict): Data with keys 'e_local_post_step' and 'work_newton' as stored by this function
            ax (Matplotlib.pyplot.axes): Somewhere to plot the error
            iter_ax (Matplotlib.pyplot.axes): Somewhere to plot the iterations
            strategy (pySDC.projects.Resilience.fault_stats.Strategy): The resilience strategy

        Returns:
            None
        """
        # adaptive strategies get a marker on every point, the others only on every 10000th point
        markevery = 1 if type(strategy) in [AdaptivityStrategy, AdaptivityPolynomialError] else 10000
        e = stats['e_local_post_step']
        ax.plot([me[0] for me in e], [me[1] for me in e], markevery=markevery, **strategy.style, **kwargs)
        k = stats['work_newton']
        iter_ax.plot(
            [me[0] for me in k], np.cumsum([me[1] for me in k]), **strategy.style, markevery=markevery, **kwargs
        )
        ax.set_yscale('log')
        ax.set_ylabel('local error')
        iter_ax.set_ylabel(r'Newton iterations')

    for strategy in [BaseStrategy, IterateStrategy, AdaptivityStrategy, AdaptivityPolynomialError]:
        S = strategy(newton_inexactness=False)
        desc = S.get_custom_description(problem=run_vdp, num_procs=1)
        desc['problem_params']['mu'] = 1000
        desc['problem_params']['u0'] = (1.1, 0)
        if strategy in [AdaptivityStrategy, BaseStrategy]:
            desc['step_params']['maxiter'] = 5
        if strategy in [BaseStrategy, IterateStrategy]:
            # the non-adaptive step size strategies use a fixed small step size
            desc['level_params']['dt'] = 1e-4
            desc['sweeper_params']['QI'] = 'LU'
        if strategy in [IterateStrategy]:
            desc['step_params']['maxiter'] = 99
            desc['level_params']['restol'] = 1e-10

        path = f'./data/adaptivity_paper_plot_data_{strategy.__name__}.pickle'
        if run:
            stats, _, _ = run_vdp(
                custom_description=desc,
                Tend=20,
                hook_class=[LogLocalErrorPostStep, LogWork, LogData],
                custom_controller_params={'logger_level': 15},
            )

            data = {
                'u': get_sorted(stats, type='u', recomputed=False),
                'e_local_post_step': get_sorted(stats, type='e_local_post_step', recomputed=False),
                'work_newton': get_sorted(stats, type='work_newton', recomputed=None),
            }
            with open(path, 'wb') as file:
                pickle.dump(data, file)
        else:
            # NOTE: only load files written by this function; unpickling untrusted files is unsafe
            with open(path, 'rb') as file:
                data = pickle.load(file)

        plot_error(data, axs[1], axs[2], strategy())

        # the solution is drawn for every strategy
        u = data['u']
        axs[0].plot([me[0] for me in u], [me[1][0] for me in u], color='black', label=r'$u$')

    axs[2].set_xlabel(r'$t$')
    axs[0].set_ylabel('solution')
    axs[2].legend(frameon=JOURNAL == 'JSC_beamer')
    axs[1].legend(frameon=True)
    axs[2].set_yscale('log')
    savefig(fig, 'adaptivity')
def plot_fault_vdp(bit=0):  # pragma: no cover
    """
    Make a plot showing the impact of a fault on van der Pol without any resilience.
    The faults are inserted in the last iteration in the last node in u_t such that you can best see the impact.

    Args:
        bit (int): The bit that you want to flip

    Returns:
        None
    """
    from pySDC.projects.Resilience.fault_stats import (
        FaultStats,
        BaseStrategy,
    )
    from pySDC.projects.Resilience.hook import LogData

    stats_analyser = FaultStats(
        prob=run_vdp,
        strategies=[BaseStrategy()],
        faults=[False, True],
        reload=True,
        recovery_thresh=1.1,
        num_procs=1,
        mode='combination',
    )

    my_setup_mpl()
    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.8, 0.5))
    colors = ['blue', 'red', 'magenta']
    ls = ['--', '-']
    markers = ['*', '^']
    do_faults = [False, True]
    superscripts = ['*', '']
    subscripts = ['', 't', '']

    # index of the run in the 'combination' mode of FaultStats; depends on the bit to flip
    run = 779 + 12 * bit  # for faults in u_t
    # run = 11 + 12 * bit # for faults in u

    for i, with_faults in enumerate(do_faults):
        stats, _, _ = stats_analyser.single_run(
            strategy=BaseStrategy(),
            run=run,
            faults=with_faults,
            hook_class=[LogData],
        )
        u = get_sorted(stats, type='u')
        faults = get_sorted(stats, type='bitflip')
        for j in [0, 1]:
            ax.plot(
                [me[0] for me in u],
                [me[1][j] for me in u],
                ls=ls[i],
                color=colors[j],
                # doubled braces give literal LaTeX braces, e.g. '$u^{*}_{t}$'
                label=rf'$u^{{{superscripts[i]}}}_{{{subscripts[j]}}}$',
                marker=markers[j],
                markevery=60,
            )
        for fault in faults:
            ax.axvline(fault[0], color='black', label='Fault', ls=':')
            print(
                f'Fault at t={fault[0]:.2e}, iter={fault[1][1]}, node={fault[1][2]}, space={fault[1][3]}, bit={fault[1][4]}'
            )
            ax.set_title(f'Fault in bit {fault[1][4]}')

    ax.legend(frameon=True, loc='lower left')
    ax.set_xlabel(r'$t$')
    savefig(fig, f'fault_bit_{bit}')
def plot_quench_solution():  # pragma: no cover
    """
    Plot the maximum of the Quench solution over time together with the problem's threshold value `u_thresh`
    and maximum value `u_max` as horizontal lines.

    Returns:
        None
    """
    my_setup_mpl()
    width, height = (0.5, 0.9) if JOURNAL == 'JSC_beamer' else (1.0, 0.45)
    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, width, height))

    base = BaseStrategy()
    description = base.get_custom_description(run_quench, num_procs=1)
    stats, controller, _ = run_quench(custom_description=description, Tend=base.get_Tend(run_quench))

    problem = controller.MS[0].levels[0].prob

    solution = get_sorted(stats, type='u', recomputed=False)
    times = [entry[0] for entry in solution]
    maxima = [max(entry[1]) for entry in solution]

    ax.plot(times, maxima, color='black', label='$T$')
    ax.axhline(problem.u_thresh, label=r'$T_\mathrm{thresh}$', ls='--', color='grey', zorder=-1)
    ax.axhline(problem.u_max, label=r'$T_\mathrm{max}$', ls=':', color='grey', zorder=-1)

    ax.set_xlabel(r'$t$')
    ax.legend(frameon=False)
    savefig(fig, 'quench_sol')
def plot_Schroedinger_solution():  # pragma: no cover
    """
    Plot the absolute value of the nonlinear Schroedinger solution at t=0 and t=1 next to each other.

    Returns:
        None

    Raises:
        NotImplementedError: If `JOURNAL` is 'JSC_beamer', for which no layout is implemented
    """
    from pySDC.implementations.problem_classes.NonlinearSchroedinger_MPIFFT import nonlinearschroedinger_imex
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    my_setup_mpl()
    if JOURNAL == 'JSC_beamer':
        raise NotImplementedError('Plotting the Schroedinger solution is not implemented for JSC_beamer')

    fig, axs = plt.subplots(1, 2, figsize=figsize_by_journal(JOURNAL, 1.0, 0.45), sharex=True, sharey=True)

    # NOTE: this changes a global matplotlib setting that also affects figures created afterwards
    plt.rcParams['figure.constrained_layout.use'] = True
    # one colorbar axis to the right of each panel
    cax = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) for ax in axs]

    problem_params = dict()
    problem_params['nvars'] = (256, 256)
    problem_params['spectral'] = False
    problem_params['c'] = 1.0
    description = {'problem_params': problem_params}
    stats, _, _ = run_Schroedinger(Tend=1.0e0, custom_description=description)

    # problem instance with the same parameters, only used for its grid `X`
    P = nonlinearschroedinger_imex(**problem_params)
    u = get_sorted(stats, type='u')

    im = axs[0].pcolormesh(*P.X, np.abs(u[0][1]), rasterized=True)
    im1 = axs[1].pcolormesh(*P.X, np.abs(u[-1][1]), rasterized=True)

    fig.colorbar(im, cax=cax[0])
    fig.colorbar(im1, cax=cax[1])
    axs[0].set_title(r'$\|u(t=0)\|$')
    axs[1].set_title(r'$\|u(t=1)\|$')
    for ax in axs:
        ax.set_aspect(1)
        ax.set_xlabel('$x$')
        ax.set_ylabel('$y$')
    savefig(fig, 'Schroedinger_sol')
def plot_AC_solution():  # pragma: no cover
    """
    Plot the initial condition of the Allen-Cahn problem (left) and the radius recorded by the `monitor` hook
    over time (right).

    Returns:
        None

    Raises:
        NotImplementedError: If `JOURNAL` is 'JSC_beamer', for which no layout is implemented
    """
    from pySDC.projects.Resilience.AC import monitor

    my_setup_mpl()
    if JOURNAL == 'JSC_beamer':
        raise NotImplementedError('Plotting the Allen-Cahn solution is not implemented for JSC_beamer')

    fig, axs = plt.subplots(1, 2, figsize=figsize_by_journal(JOURNAL, 1.0, 0.45))

    description = {'problem_params': {'nvars': (256, 256)}}
    stats, _, _ = run_AC(Tend=0.032, hook_class=monitor, custom_description=description)

    u = get_sorted(stats, type='u')

    computed_radius = get_sorted(stats, type='computed_radius')
    axs[1].plot([me[0] for me in computed_radius], [me[1] for me in computed_radius], ls='-')
    axs[1].axvline(0.025, ls=':', label=r'$t=0.025$', color='grey')
    axs[1].set_title('Radius over time')
    axs[1].set_xlabel('$t$')
    axs[1].legend(frameon=False)

    im = axs[0].imshow(u[0][1], extent=(-0.5, 0.5, -0.5, 0.5))
    # attach the colorbar explicitly to the image panel rather than relying on the current axes
    fig.colorbar(im, ax=axs[0])
    axs[0].set_title(r'$u_0$')
    axs[0].set_xlabel('$x$')
    axs[0].set_ylabel('$y$')
    savefig(fig, 'AC_sol')
def plot_vdp_solution():  # pragma: no cover
    """
    Plot the first component of a stiff (mu=1000) van der Pol solution over time to illustrate the varying time
    scales, with a shaded window of 20 time units starting where the solution first comes close to 1.1.

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    my_setup_mpl()
    width, height = (0.5, 0.9) if JOURNAL == 'JSC_beamer' else (1.0, 0.33)
    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, width, height))

    description = {
        'convergence_controllers': {Adaptivity: {'e_tol': 1e-7, 'dt_max': 1e0}},
        'problem_params': {'mu': 1000, 'crash_at_maxiter': False},
        'level_params': {'dt': 1e-3},
    }

    stats, _, _ = run_vdp(custom_description=description, Tend=2000)

    solution = get_sorted(stats, type='u', recomputed=False)
    first_component = np.array([entry[1][0] for entry in solution])
    times = np.array([entry[0] for entry in solution])

    # first time at which u is within 1e-2 of 1.1; presumably marks the interval simulated in
    # `plot_adaptivity_stuff` (u0=(1.1, 0), Tend=20)
    t_start = times[abs(first_component - 1.1) < 1e-2][0]

    ax.plot(times, first_component, color='black')
    ax.axvspan(t_start, t_start + 20, alpha=0.4)
    ax.set_ylabel(r'$u$')
    ax.set_xlabel(r'$t$')
    savefig(fig, 'vdp_sol')
def work_precision():  # pragma: no cover
    """
    Make the work-precision plots via `all_problems` from the work_precision module, storing them in 'data/paper'.

    Plots are made for the modes 'compare_strategies', 'parallel_efficiency' and 'RK_comp' with time as work
    measure, plus 'compare_strategies' once more with 'param' as work measure.
    """
    from pySDC.projects.Resilience.work_precision import (
        all_problems,
    )

    common = {
        'record': False,
        'work_key': 't',
        'precision_key': 'e_global_rel',
        'plotting': True,
        'base_path': 'data/paper',
    }

    configurations = [(common, mode) for mode in ('compare_strategies', 'parallel_efficiency', 'RK_comp')]
    configurations.append(({**common, 'work_key': 'param'}, 'compare_strategies'))

    for params, mode in configurations:
        all_problems(**params, mode=mode)
def make_plots_for_TIME_X_website():  # pragma: no cover
    """
    Make png plots for the TIME-X website in 'data/paper/time-x_website'.

    Note that this reassigns the module level `JOURNAL` and `BASE_PATH`.
    """
    global JOURNAL, BASE_PATH
    JOURNAL, BASE_PATH = 'JSC_beamer', 'data/paper/time-x_website'

    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.5, 2.0 / 3.0))
    vdp_stats = get_stats(run_vdp)
    plot_recovery_rate_recoverable_only(vdp_stats, fig, ax)
    savefig(fig, 'recovery_rate', format='png')

    from pySDC.projects.Resilience.work_precision import vdp_stiffness_plot

    vdp_stiffness_plot(base_path=BASE_PATH, format='png')
def make_plots_for_SIAM_CSE23():  # pragma: no cover
    """
    Make plots for the SIAM CSE23 talk in 'data/paper/SIAMCSE23'.

    Note that this reassigns the module level `JOURNAL` and `BASE_PATH`.
    """
    global JOURNAL, BASE_PATH
    JOURNAL, BASE_PATH = 'JSC_beamer', 'data/paper/SIAMCSE23'

    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.5, 3.0 / 4.0))
    vdp_stats = get_stats(run_vdp)
    plot_recovery_rate_recoverable_only(vdp_stats, fig, ax)
    savefig(fig, 'recovery_rate')

    for make_plot in (plot_adaptivity_stuff, compare_recovery_rate_problems, plot_vdp_solution):
        make_plot()
def make_plots_for_adaptivity_paper():  # pragma: no cover
    """
    Make the plots for the adaptivity paper in 'data/paper' using the 'Springer_Numerical_Algorithms' layout.

    Note that this reassigns the module level `JOURNAL` and `BASE_PATH`.
    """
    global JOURNAL, BASE_PATH
    JOURNAL, BASE_PATH = 'Springer_Numerical_Algorithms', 'data/paper'

    plot_adaptivity_stuff()
    work_precision()

    for plot_solution in (plot_vdp_solution, plot_AC_solution, plot_Schroedinger_solution, plot_quench_solution):
        plot_solution()
def make_plots_for_resilience_paper():  # pragma: no cover
    """
    Make the plots for the resilience paper, using whatever `JOURNAL` and `BASE_PATH` are currently set.
    """
    plot_recovery_rate(get_stats(run_vdp))
    for bit in (0, 13):
        plot_fault_vdp(bit)
    compare_recovery_rate_problems(num_procs=1, strategy_type='SDC')
def make_plots_for_notes():  # pragma: no cover
    """
    Make png plots for the notes for the website / GitHub in 'notes/Lorenz'.

    Note that this reassigns the module level `JOURNAL` and `BASE_PATH`.
    """
    global JOURNAL, BASE_PATH
    JOURNAL, BASE_PATH = 'Springer_Numerical_Algorithms', 'notes/Lorenz'

    for problem in (run_Lorenz, run_quench):
        analyse_resilience(problem, format='png')
if __name__ == "__main__":
    # Pick the set of plots to make by (un)commenting the calls below.
    # make_plots_for_notes()
    # make_plots_for_SIAM_CSE23()
    # make_plots_for_TIME_X_website()
    make_plots_for_adaptivity_paper()