Coverage for pySDC/projects/Resilience/paper_plots.py: 0%
26 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1# script to make pretty plots for papers or talks
2import numpy as np
3import matplotlib as mpl
4import matplotlib.pyplot as plt
5from pySDC.projects.Resilience.fault_stats import (
6 FaultStats,
7 run_Lorenz,
8 run_Schroedinger,
9 run_vdp,
10 run_quench,
11 run_AC,
12 RECOVERY_THRESH_ABS,
13)
14from pySDC.projects.Resilience.strategies import (
15 BaseStrategy,
16 AdaptivityStrategy,
17 IterateStrategy,
18 HotRodStrategy,
19 DIRKStrategy,
20 ERKStrategy,
21 AdaptivityPolynomialError,
22)
23from pySDC.helpers.plot_helper import setup_mpl, figsize_by_journal
24from pySDC.helpers.stats_helper import get_sorted
# Layout constants shared by the plotting functions below.
cm = 1.0 / 2.5  # factor to convert centimetres to (roughly) inches, the unit of matplotlib figure sizes
TEXTWIDTH = 11.9446244611 * cm  # width of the text in the target document, given in cm and converted
JOURNAL = 'Springer_Numerical_Algorithms'  # layout name handed to figsize_by_journal
BASE_PATH = 'data/paper'  # directory where savefig stores the figures
def get_stats(problem, path='data/stats-jusuf', num_procs=1, strategy_type='SDC'):
    """
    Create a FaultStats object for a given problem to use for the plots.
    Note that the statistics need to be already generated somewhere else, this function will only load them.

    Args:
        problem (function): A problem to run
        path (str): Path to the associated stats for the problem
        num_procs (int): Number of processes, passed on to FaultStats
        strategy_type (str): Which family of strategies to load, either 'SDC' or 'RK'

    Returns:
        FaultStats: Object to analyse resilience statistics from

    Raises:
        ValueError: If `strategy_type` is neither 'SDC' nor 'RK'
    """
    if strategy_type == 'SDC':
        strategies = [BaseStrategy(), AdaptivityStrategy(), IterateStrategy()]
        # HotRod and polynomial-error adaptivity are left out for the 'JSC_beamer' layout
        if JOURNAL not in ['JSC_beamer']:
            strategies += [HotRodStrategy(), AdaptivityPolynomialError()]
    elif strategy_type == 'RK':
        strategies = [DIRKStrategy()]
        # explicit RK is only added for the Lorenz and van der Pol problems
        if problem.__name__ in ['run_Lorenz', 'run_vdp']:
            strategies += [ERKStrategy()]
    else:
        # without this, `strategies` would be unbound below and fail with an unhelpful UnboundLocalError
        raise ValueError(f"Unknown strategy_type {strategy_type!r}, expected 'SDC' or 'RK'")

    stats_analyser = FaultStats(
        prob=problem,
        strategies=strategies,
        faults=[False, True],
        reload=True,
        recovery_thresh=1.1,
        recovery_thresh_abs=RECOVERY_THRESH_ABS.get(problem, 0),
        mode='default',
        stats_path=path,
        num_procs=num_procs,
    )
    stats_analyser.get_recovered()
    return stats_analyser
def my_setup_mpl(**kwargs):
    """
    Apply the matplotlib defaults shared by all figures in this script.

    Args:
        **kwargs: Accepted for interface compatibility, but not used

    Returns:
        None
    """
    setup_mpl(reset=True, font_size=8)
    mpl.rcParams['lines.markersize'] = 6
def savefig(fig, name, format='pdf', tight_layout=True):  # pragma: no cover
    """
    Save a figure to some predefined location.

    The figure is written to `{BASE_PATH}/{name}.{format}`; missing directories are created first.

    Args:
        fig (Matplotlib.Figure): The figure of the plot
        name (str): The name of the plot, without file extension
        format (str): File extension of the saved figure, e.g. 'pdf' or 'png'
        tight_layout (bool): Apply tight layout or leave as is

    Returns:
        None
    """
    import os

    if tight_layout:
        fig.tight_layout()
    path = f'{BASE_PATH}/{name}.{format}'
    # Figure.savefig fails if the target directory does not exist, e.g. for alternative BASE_PATHs
    os.makedirs(os.path.dirname(path), exist_ok=True)
    fig.savefig(path, bbox_inches='tight', transparent=True, dpi=200)
    print(f'saved "{path}"')
def analyse_resilience(problem, path='data/stats', **kwargs):  # pragma: no cover
    """
    Load resilience statistics for a problem, print the faults that the iterate strategy did not recover from and
    plot the strategy comparison as well as the recovery rates.

    Args:
        problem (function): A problem to run
        path (str): Path to the associated stats for the problem
        **kwargs: Forwarded to `compare_strategies` and `plot_recovery_rate`, and from there to `savefig`

    Returns:
        None
    """
    analyser = get_stats(problem, path)
    # NOTE(review): get_stats already calls get_recovered; the second call is kept as in the original
    analyser.get_recovered()

    iterate = IterateStrategy()
    unrecovered = analyser.get_mask(strategy=iterate, key='recovered', val=False)
    # additionally exclude faults in bit 1 (presumably those causing overflow)
    unrecovered_without_bit_1 = analyser.get_mask(
        strategy=iterate, key='bit', val=1, op='uneq', old_mask=unrecovered
    )
    analyser.print_faults(unrecovered_without_bit_1)

    for plot in (compare_strategies, plot_recovery_rate):
        plot(analyser, **kwargs)
def compare_strategies(stats_analyser, **kwargs):  # pragma: no cover
    """
    Let the stats analyser compare all of its strategies in a single axis and save the result as
    "compare_strategies".

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        **kwargs: Forwarded to `savefig`

    Returns:
        None
    """
    my_setup_mpl()
    figure, axis = plt.subplots(figsize=(TEXTWIDTH, 5 * cm))
    stats_analyser.compare_strategies(ax=axis)
    savefig(figure, 'compare_strategies', **kwargs)
def plot_recovery_rate(stats_analyser, **kwargs):  # pragma: no cover
    """
    Plot the recovery rate per bit side by side: on the left for all faults, on the right only for faults that can
    be recovered. The figure is saved as "recovery_rate_compared".

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        **kwargs: Forwarded to `savefig`

    Returns:
        None
    """
    my_setup_mpl()
    fig, axs = plt.subplots(1, 2, figsize=(TEXTWIDTH, 5 * cm), sharex=True, sharey=True)
    ax_all, ax_fixable = axs

    stats_analyser.plot_things_per_things(
        'recovered',
        'bit',
        False,
        op=stats_analyser.rec_rate,
        args={'ylabel': 'recovery rate'},
        plotting_args={'markevery': 5},
        ax=ax_all,
    )
    plot_recovery_rate_recoverable_only(stats_analyser, fig, ax_fixable, ylabel='')

    # only the right panel keeps a legend
    ax_all.get_legend().remove()
    ax_all.set_title('All faults')
    ax_fixable.set_title('Only recoverable faults')
    ax_all.set_ylim((-0.05, 1.05))
    savefig(fig, 'recovery_rate_compared', **kwargs)
def plot_recovery_rate_recoverable_only(stats_analyser, fig, ax, **kwargs):  # pragma: no cover
    """
    Plot the recovery rate considering only faults that can be recovered theoretically.

    Each strategy is drawn separately, masked by the faults that are fixable for that particular strategy.

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        fig (matplotlib.pyplot.figure): Figure in which to plot
        ax (matplotlib.pyplot.axes): Somewhere to plot
        **kwargs: Passed as `args` to `plot_things_per_things`, e.g. `ylabel` or `title`

    Returns:
        None
    """
    for strategy in stats_analyser.strategies:
        fixable_mask = stats_analyser.get_fixable_faults_only(strategy=strategy)
        stats_analyser.plot_things_per_things(
            'recovered',
            'bit',
            False,
            op=stats_analyser.rec_rate,
            mask=fixable_mask,
            args=dict(kwargs),
            ax=ax,
            fig=fig,
            strategies=[strategy],
            plotting_args={'markevery': 5},
        )
def compare_recovery_rate_problems(**kwargs):  # pragma: no cover
    """
    Compare the recovery rate for the van der Pol, Quench, Schroedinger and Allen-Cahn problems.
    Only faults that can be recovered are shown.

    Args:
        **kwargs: Passed to `get_stats` (e.g. `num_procs`, `strategy_type`) and encoded in the file name

    Returns:
        None
    """
    stats = [
        get_stats(run_vdp, **kwargs),
        get_stats(run_quench, **kwargs),
        get_stats(run_Schroedinger, **kwargs),
        get_stats(run_AC, **kwargs),
    ]
    titles = ['Van der Pol', 'Quench', r'Schr\"odinger', 'Allen-Cahn']

    my_setup_mpl()
    fig, axs = plt.subplots(2, 2, figsize=figsize_by_journal(JOURNAL, 1, 0.8), sharey=True)
    for stats_analyser, ax, title in zip(stats, axs.flatten(), titles):
        plot_recovery_rate_recoverable_only(stats_analyser, fig, ax, ylabel='', title=title)

    # drop the per-panel legends and show a single one below
    for ax in axs.flatten():
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    if kwargs.get('strategy_type', 'SDC') == 'SDC':
        axs[1, 1].legend(frameon=False, loc="lower right")
    else:
        axs[0, 1].legend(frameon=False, loc="lower right")
    axs[0, 0].set_ylim((-0.05, 1.05))
    axs[1, 0].set_ylabel('recovery rate')
    axs[0, 0].set_ylabel('recovery rate')

    name = ''.join(f'_{key}-{val}' for key, val in kwargs.items())

    # savefig appends the file extension itself; passing '.pdf' here used to produce '*.pdf.pdf'
    savefig(fig, f'compare_equations{name}')
def plot_adaptivity_stuff():  # pragma: no cover
    """
    Plot the solution for a van der Pol problem as well as the local error and cost associated with the base scheme and
    adaptivity in k and dt in order to demonstrate that adaptivity is useful.

    The top panel shows the first solution component of the base strategy, the middle panel the local error and the
    bottom panel the cumulative Newton work for the base, adaptivity, iterate and polynomial-error adaptivity
    strategies. The figure is saved as "adaptivity".

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep
    from pySDC.implementations.hooks.log_work import LogWork
    from pySDC.projects.Resilience.hook import LogData

    stats_analyser = get_stats(run_vdp, 'data/stats')

    my_setup_mpl()
    scale = 0.5 if JOURNAL == 'JSC_beamer' else 1.0
    fig, axs = plt.subplots(3, 1, figsize=figsize_by_journal(JOURNAL, scale, 1), sharex=True, sharey=False)

    def plot_error(stats, ax, iter_ax, strategy, **kwargs):
        """
        Plot local error and cumulative sum of Newton iterations

        Args:
            stats (dict): Stats from pySDC run
            ax (Matplotlib.pyplot.axes): Somewhere to plot the error
            iter_ax (Matplotlib.pyplot.axes): Somewhere to plot the iterations
            strategy (pySDC.projects.Resilience.fault_stats.Strategy): The resilience strategy
            **kwargs: Additional plotting arguments for both lines

        Returns:
            None
        """
        markevery = 40
        e = get_sorted(stats, type='e_local_post_step', recomputed=False)
        ax.plot([me[0] for me in e], [me[1] for me in e], markevery=markevery, **strategy.style, **kwargs)
        k = get_sorted(stats, type='work_newton')
        # cumulative Newton work over time
        iter_ax.plot(
            [me[0] for me in k], np.cumsum([me[1] for me in k]), **strategy.style, markevery=markevery, **kwargs
        )
        ax.set_yscale('log')
        ax.set_ylabel('local error')
        iter_ax.set_ylabel(r'Newton iterations')

    force_params = {}
    for strategy in [BaseStrategy, AdaptivityStrategy, IterateStrategy, AdaptivityPolynomialError]:
        if strategy == AdaptivityPolynomialError:
            from pySDC.implementations.convergence_controller_classes.adaptivity import (
                AdaptivityPolynomialError as adaptivity,
            )

            # override the defaults of the polynomial-error adaptivity for this plot
            force_params = {'sweeper_params': {'num_nodes': 2}}
            force_params['convergence_controllers'] = {
                adaptivity: {
                    'e_tol': 7e-5,
                    'restol_rel': 1e-4,
                    'restol_min': 1e-10,
                    'restart_at_maxiter': True,
                    'factor_if_not_converged': 4.0,
                },
            }
        else:
            force_params = {}
        stats, _, _ = stats_analyser.single_run(
            strategy=strategy(useMPI=False),
            force_params=force_params,
            hook_class=[LogLocalErrorPostStep, LogData, LogWork],
        )
        # a fresh strategy instance is created just for its plotting style
        plot_error(stats, axs[1], axs[2], strategy())

        # the solution itself is only plotted once, from the base strategy run
        if strategy == BaseStrategy:
            u = get_sorted(stats, type='u', recomputed=False)
            axs[0].plot([me[0] for me in u], [me[1][0] for me in u], color='black', label=r'$u$')

    axs[2].set_xlabel(r'$t$')
    axs[0].set_ylabel('solution')
    axs[2].legend(frameon=JOURNAL == 'JSC_beamer')
    axs[1].legend(frameon=True)
    savefig(fig, 'adaptivity')
def plot_fault_vdp(bit=0):  # pragma: no cover
    """
    Make a plot showing the impact of a fault on van der Pol without any resilience.
    The faults are inserted in the last iteration in the last node in u_t such that you can best see the impact.

    A run without faults (dashed) and one with faults (solid) are plotted together, with vertical dotted lines
    marking the fault times. The figure is saved as "fault_bit_{bit}".

    Args:
        bit (int): The bit that you want to flip

    Returns:
        None
    """
    from pySDC.projects.Resilience.fault_stats import (
        FaultStats,
        BaseStrategy,
    )
    from pySDC.projects.Resilience.hook import LogData

    stats_analyser = FaultStats(
        prob=run_vdp,
        strategies=[BaseStrategy()],
        faults=[False, True],
        reload=True,
        recovery_thresh=1.1,
        num_procs=1,
        mode='combination',
    )

    my_setup_mpl()
    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.8, 0.5))
    colors = ['blue', 'red', 'magenta']
    ls = ['--', '-']
    markers = ['*', '^']
    do_faults = [False, True]
    superscripts = ['*', '']
    subscripts = ['', 't', '']

    run = 779 + 12 * bit  # for faults in u_t
    # run = 11 + 12 * bit  # for faults in u

    for i, do_fault in enumerate(do_faults):
        stats, _, _ = stats_analyser.single_run(
            strategy=BaseStrategy(),
            run=run,
            faults=do_fault,
            hook_class=[LogData],
        )
        u = get_sorted(stats, type='u')
        faults = get_sorted(stats, type='bitflip')
        for j in [0, 1]:
            ax.plot(
                [me[0] for me in u],
                [me[1][j] for me in u],
                ls=ls[i],
                color=colors[j],
                # '{{' and '}}' are literal LaTeX group braces around the interpolated value; a single
                # '{ {...}}' would interpolate a Python set literal instead
                label=rf'$u^{{{superscripts[i]}}}_{{{subscripts[j]}}}$',
                marker=markers[j],
                markevery=60,
            )
        for fault in faults:
            ax.axvline(fault[0], color='black', label='Fault', ls=':')
            print(
                f'Fault at t={fault[0]:.2e}, iter={fault[1][1]}, node={fault[1][2]}, space={fault[1][3]}, bit={fault[1][4]}'
            )
            ax.set_title(f'Fault in bit {fault[1][4]}')

    ax.legend(frameon=True, loc='lower left')
    ax.set_xlabel(r'$t$')
    savefig(fig, f'fault_bit_{bit}')
def plot_quench_solution():  # pragma: no cover
    """
    Plot the maximum of the Quench solution over time, together with the threshold and maximum values stored in the
    problem. The figure is saved as "quench_sol".

    Returns:
        None
    """
    my_setup_mpl()
    beamer = JOURNAL == 'JSC_beamer'
    fig, ax = plt.subplots(
        figsize=figsize_by_journal(JOURNAL, 0.5, 0.9) if beamer else figsize_by_journal(JOURNAL, 1.0, 0.45)
    )

    base = BaseStrategy()
    description = base.get_custom_description(run_quench, num_procs=1)
    stats, controller, _ = run_quench(custom_description=description, Tend=base.get_Tend(run_quench))

    problem = controller.MS[0].levels[0].prob

    solution = get_sorted(stats, type='u', recomputed=False)
    times = [entry[0] for entry in solution]
    peak = [max(entry[1]) for entry in solution]

    ax.plot(times, peak, color='black', label='$T$')
    ax.axhline(problem.u_thresh, label=r'$T_\mathrm{thresh}$', ls='--', color='grey', zorder=-1)
    ax.axhline(problem.u_max, label=r'$T_\mathrm{max}$', ls=':', color='grey', zorder=-1)

    ax.set_xlabel(r'$t$')
    ax.legend(frameon=False)
    savefig(fig, 'quench_sol')
def plot_AC_solution():  # pragma: no cover
    """
    Plot the first recorded Allen-Cahn solution (titled u_0) next to the computed and exact radius over time.
    The figure is saved as "AC_sol".

    Returns:
        None

    Raises:
        NotImplementedError: For the 'JSC_beamer' layout, which is not supported
    """
    from pySDC.projects.TOMS.AllenCahn_monitor import monitor

    my_setup_mpl()
    if JOURNAL == 'JSC_beamer':
        raise NotImplementedError('plot_AC_solution does not support the JSC_beamer layout')
    fig, axs = plt.subplots(1, 2, figsize=figsize_by_journal(JOURNAL, 1.0, 0.45))

    stats, _, _ = run_AC(Tend=0.032, hook_class=monitor)

    u = get_sorted(stats, type='u')

    computed_radius = get_sorted(stats, type='computed_radius')
    exact_radius = get_sorted(stats, type='exact_radius')
    axs[1].plot([me[0] for me in computed_radius], [me[1] for me in computed_radius], ls='-', label='numerical')
    axs[1].plot([me[0] for me in exact_radius], [me[1] for me in exact_radius], ls='--', color='black', label='exact')
    axs[1].axvline(0.025, ls=':', label=r'$t=0.025$', color='grey')
    axs[1].set_title('Radius over time')
    axs[1].set_xlabel('$t$')
    axs[1].legend(frameon=False)

    im = axs[0].imshow(u[0][1], extent=(-0.5, 0.5, -0.5, 0.5))
    fig.colorbar(im)
    axs[0].set_title(r'$u_0$')
    axs[0].set_xlabel('$x$')
    axs[0].set_ylabel('$y$')
    savefig(fig, 'AC_sol')
def plot_vdp_solution():  # pragma: no cover
    """
    Plot the first component of the van der Pol solution, computed with adaptivity, over time to illustrate the
    varying time scales. The figure is saved as "vdp_sol".

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    my_setup_mpl()
    if JOURNAL != 'JSC_beamer':
        fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 1.0, 0.33))
    else:
        fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.5, 0.9))

    description = {'convergence_controllers': {Adaptivity: {'e_tol': 1e-7}}}
    stats, _, _ = run_vdp(custom_description=description, Tend=28.6)

    solution = get_sorted(stats, type='u')
    ax.plot([entry[0] for entry in solution], [entry[1][0] for entry in solution], color='black')
    ax.set_ylabel(r'$u$')
    ax.set_xlabel(r'$t$')
    savefig(fig, 'vdp_sol')
def work_precision():  # pragma: no cover
    """
    Make the work-precision plots via `all_problems` from the Resilience work_precision module. The modes used are
    'compare_strategies', 'parallel_efficiency' and 'RK_comp', and the output goes to 'data/paper'.

    Returns:
        None
    """
    from pySDC.projects.Resilience.work_precision import all_problems

    shared_params = dict(
        record=False,
        work_key='t',
        precision_key='e_global_rel',
        plotting=True,
        base_path='data/paper',
    )
    for mode in ('compare_strategies', 'parallel_efficiency', 'RK_comp'):
        all_problems(**shared_params, mode=mode)
def make_plots_for_TIME_X_website():  # pragma: no cover
    """
    Make png plots for the TIME-X website using the 'JSC_beamer' layout.
    This switches the module-level JOURNAL and BASE_PATH for all subsequent plots.
    """
    global JOURNAL, BASE_PATH
    JOURNAL = 'JSC_beamer'
    BASE_PATH = 'data/paper/time-x_website'

    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.5, 2.0 / 3.0))
    vdp_stats = get_stats(run_vdp)
    plot_recovery_rate_recoverable_only(vdp_stats, fig, ax)
    savefig(fig, 'recovery_rate', format='png')

    from pySDC.projects.Resilience.work_precision import vdp_stiffness_plot

    vdp_stiffness_plot(base_path=BASE_PATH, format='png')
def make_plots_for_SIAM_CSE23():  # pragma: no cover
    """
    Make plots for the SIAM CSE23 talk using the 'JSC_beamer' layout.
    This switches the module-level JOURNAL and BASE_PATH for all subsequent plots.
    """
    global JOURNAL, BASE_PATH
    JOURNAL = 'JSC_beamer'
    BASE_PATH = 'data/paper/SIAMCSE23'

    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.5, 3.0 / 4.0))
    vdp_stats = get_stats(run_vdp)
    plot_recovery_rate_recoverable_only(vdp_stats, fig, ax)
    savefig(fig, 'recovery_rate')

    for make_plot in (plot_adaptivity_stuff, compare_recovery_rate_problems, plot_vdp_solution):
        make_plot()
def make_plots_for_paper():  # pragma: no cover
    """
    Make plots that are supposed to go in the paper, using the Springer Numerical Algorithms layout.
    This switches the module-level JOURNAL and BASE_PATH for all subsequent plots.
    """
    global JOURNAL, BASE_PATH
    JOURNAL = 'Springer_Numerical_Algorithms'
    BASE_PATH = 'data/paper'

    plot_adaptivity_stuff()

    work_precision()

    for plot_solution in (plot_vdp_solution, plot_AC_solution, plot_quench_solution):
        plot_solution()

    plot_recovery_rate(get_stats(run_vdp))
    for flipped_bit in (0, 13):
        plot_fault_vdp(flipped_bit)
    compare_recovery_rate_problems(num_procs=1, strategy_type='SDC')
def make_plots_for_notes():  # pragma: no cover
    """
    Make png plots for the notes for the website / GitHub.
    This switches the module-level JOURNAL and BASE_PATH for all subsequent plots.
    """
    global JOURNAL, BASE_PATH
    JOURNAL = 'Springer_Numerical_Algorithms'
    BASE_PATH = 'notes/Lorenz'

    # NOTE(review): analyse_resilience saves 'compare_strategies' and 'recovery_rate_compared' for every problem
    # into the same BASE_PATH, so the quench plots overwrite the Lorenz ones. Confirm whether BASE_PATH should
    # change between the two calls.
    for problem in (run_Lorenz, run_quench):
        analyse_resilience(problem, format='png')
if __name__ == "__main__":
    # Alternative entry points: make_plots_for_notes(), make_plots_for_SIAM_CSE23(),
    # make_plots_for_TIME_X_website()
    make_plots_for_paper()