Coverage for pySDC/projects/Resilience/paper_plots.py: 0%

26 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1# script to make pretty plots for papers or talks 

2import numpy as np 

3import matplotlib as mpl 

4import matplotlib.pyplot as plt 

5from pySDC.projects.Resilience.fault_stats import ( 

6 FaultStats, 

7 run_Lorenz, 

8 run_Schroedinger, 

9 run_vdp, 

10 run_quench, 

11 run_AC, 

12 RECOVERY_THRESH_ABS, 

13) 

14from pySDC.projects.Resilience.strategies import ( 

15 BaseStrategy, 

16 AdaptivityStrategy, 

17 IterateStrategy, 

18 HotRodStrategy, 

19 DIRKStrategy, 

20 ERKStrategy, 

21 AdaptivityPolynomialError, 

22) 

23from pySDC.helpers.plot_helper import setup_mpl, figsize_by_journal 

24from pySDC.helpers.stats_helper import get_sorted 

25 

26 

# Conversion factor from centimetres to inches (matplotlib figure sizes are given in inches); 1 in = 2.54 cm exactly.
cm = 1 / 2.54
# Text width in cm (presumably that of the template selected by JOURNAL), converted to inches.
TEXTWIDTH = 11.9446244611 * cm
# Journal / venue used to size figures via ``figsize_by_journal``; overridden by the ``make_plots_for_*`` functions.
JOURNAL = 'Springer_Numerical_Algorithms'
# Directory where ``savefig`` writes figures; overridden by the ``make_plots_for_*`` functions.
BASE_PATH = 'data/paper'

31 

32 

def get_stats(problem, path='data/stats-jusuf', num_procs=1, strategy_type='SDC'):
    """
    Create a FaultStats object for a given problem to use for the plots.
    Note that the statistics need to be already generated somewhere else, this function will only load them.

    Args:
        problem (function): A problem to run
        path (str): Path to the associated stats for the problem
        num_procs (int): Number of processes, forwarded to FaultStats
        strategy_type (str): Which family of strategies to load, either 'SDC' or 'RK'

    Returns:
        FaultStats: Object to analyse resilience statistics from

    Raises:
        ValueError: If ``strategy_type`` is neither 'SDC' nor 'RK'
    """
    if strategy_type == 'SDC':
        strategies = [BaseStrategy(), AdaptivityStrategy(), IterateStrategy()]
        # Hot Rod and polynomial-error adaptivity are left out of the beamer plots
        if JOURNAL not in ['JSC_beamer']:
            strategies += [HotRodStrategy(), AdaptivityPolynomialError()]
    elif strategy_type == 'RK':
        strategies = [DIRKStrategy()]
        # the explicit RK strategy is only added for the Lorenz and van der Pol problems
        if problem.__name__ in ['run_Lorenz', 'run_vdp']:
            strategies += [ERKStrategy()]
    else:
        # without this branch ``strategies`` would be unbound below and fail with an unhelpful UnboundLocalError
        raise ValueError(f'Unknown strategy_type {strategy_type!r}, expected \'SDC\' or \'RK\'')

    stats_analyser = FaultStats(
        prob=problem,
        strategies=strategies,
        faults=[False, True],
        reload=True,
        recovery_thresh=1.1,
        recovery_thresh_abs=RECOVERY_THRESH_ABS.get(problem, 0),
        mode='default',
        stats_path=path,
        num_procs=num_procs,
    )
    stats_analyser.get_recovered()
    return stats_analyser

67 

68 

def my_setup_mpl(**kwargs):
    """
    Reset matplotlib to the project style via ``setup_mpl`` with font size 8, then set the line marker size to 6.

    Args:
        **kwargs: Accepted for call compatibility; currently ignored
    """
    setup_mpl(reset=True, font_size=8)
    # must come after the reset, otherwise setup_mpl would discard it
    mpl.rcParams['lines.markersize'] = 6

72 

73 

def savefig(fig, name, format='pdf', tight_layout=True):  # pragma: no cover
    """
    Write a figure to ``BASE_PATH`` as ``<name>.<format>`` and report the location.

    Args:
        fig (Matplotlib.Figure): The figure of the plot
        name (str): The file name of the plot, without extension
        format (str): File extension, which also determines the output format
        tight_layout (bool): Apply tight layout before saving or leave the layout as is

    Returns:
        None
    """
    file_path = '{}/{}.{}'.format(BASE_PATH, name, format)
    if tight_layout:
        fig.tight_layout()
    fig.savefig(file_path, bbox_inches='tight', transparent=True, dpi=200)
    print(f'saved "{file_path}"')

90 

91 

def analyse_resilience(problem, path='data/stats', **kwargs):  # pragma: no cover
    """
    Load resilience statistics for a problem and plot them.

    Prints the faults that the iterate strategy did not recover from and whose flipped bit is not 1.
    Then plots the strategy comparison and the recovery rates.

    Args:
        problem (function): A problem to run
        path (str): Path to the associated stats for the problem
        **kwargs: Forwarded to the plotting functions, which pass them on to ``savefig``

    Returns:
        None
    """
    analyser = get_stats(problem, path)
    # NOTE(review): get_stats already calls get_recovered(); confirm it is idempotent before dropping this call
    analyser.get_recovered()

    iterate = IterateStrategy()
    unrecovered = analyser.get_mask(strategy=iterate, key='recovered', val=False)
    analyser.print_faults(analyser.get_mask(strategy=iterate, key='bit', val=1, op='uneq', old_mask=unrecovered))

    for plotter in (compare_strategies, plot_recovery_rate):
        plotter(analyser, **kwargs)

114 

115 

def compare_strategies(stats_analyser, **kwargs):  # pragma: no cover
    """
    Draw the strategy comparison of ``stats_analyser`` into a single panel and save it as 'compare_strategies'.

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        **kwargs: Forwarded to ``savefig``

    Returns:
        None
    """
    my_setup_mpl()
    size = (TEXTWIDTH, 5 * cm)
    figure, axis = plt.subplots(figsize=size)
    stats_analyser.compare_strategies(ax=axis)
    savefig(figure, 'compare_strategies', **kwargs)

130 

131 

def plot_recovery_rate(stats_analyser, **kwargs):  # pragma: no cover
    """
    Plot the recovery rate per flipped bit in two panels: all faults (left), only recoverable faults (right).

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        **kwargs: Forwarded to ``savefig``

    Returns:
        None
    """
    my_setup_mpl()
    fig, (ax_all, ax_fixable) = plt.subplots(1, 2, figsize=(TEXTWIDTH, 5 * cm), sharex=True, sharey=True)

    stats_analyser.plot_things_per_things(
        'recovered',
        'bit',
        False,
        op=stats_analyser.rec_rate,
        args={'ylabel': 'recovery rate'},
        plotting_args={'markevery': 5},
        ax=ax_all,
    )
    plot_recovery_rate_recoverable_only(stats_analyser, fig, ax_fixable, ylabel='')

    # only the right panel keeps a legend
    ax_all.get_legend().remove()
    for axis, title in ((ax_all, 'All faults'), (ax_fixable, 'Only recoverable faults')):
        axis.set_title(title)
    ax_all.set_ylim((-0.05, 1.05))
    savefig(fig, 'recovery_rate_compared', **kwargs)

159 

160 

def plot_recovery_rate_recoverable_only(stats_analyser, fig, ax, **kwargs):  # pragma: no cover
    """
    Plot the recovery rate considering only faults that can be recovered theoretically.

    Each strategy is plotted with a separate call because the mask of recoverable faults is computed per strategy.

    Args:
        stats_analyser (FaultStats): Fault stats object, which contains some stats
        fig (matplotlib.pyplot.figure): Figure in which to plot
        ax (matplotlib.pyplot.axes): Somewhere to plot
        **kwargs: Passed as ``args`` to ``plot_things_per_things`` (callers use e.g. ``ylabel`` and ``title``)

    Returns:
        None
    """
    for strategy in stats_analyser.strategies:
        # mask selecting only the faults this particular strategy could recover from
        fixable = stats_analyser.get_fixable_faults_only(strategy=strategy)

        stats_analyser.plot_things_per_things(
            'recovered',
            'bit',
            False,
            op=stats_analyser.rec_rate,
            mask=fixable,
            args={**kwargs},
            ax=ax,
            fig=fig,
            strategies=[strategy],
            plotting_args={'markevery': 5},
        )

188 

189 

def compare_recovery_rate_problems(**kwargs):  # pragma: no cover
    """
    Compare the recovery rate for the van der Pol, Quench, Schroedinger and Allen-Cahn problems in a 2x2 grid.
    Only faults that can be recovered are shown.

    Args:
        **kwargs: Passed on to ``get_stats`` (e.g. ``num_procs``, ``strategy_type``) and encoded in the file name

    Returns:
        None
    """
    problems = [run_vdp, run_quench, run_Schroedinger, run_AC]
    titles = ['Van der Pol', 'Quench', r'Schr\"odinger', 'Allen-Cahn']
    stats = [get_stats(problem, **kwargs) for problem in problems]

    my_setup_mpl()
    fig, axs = plt.subplots(2, 2, figsize=figsize_by_journal(JOURNAL, 1, 0.8), sharey=True)
    for stats_analyser, ax, title in zip(stats, axs.flatten(), titles):
        plot_recovery_rate_recoverable_only(stats_analyser, fig, ax, ylabel='', title=title)

    # drop the per-panel legends and draw a single one instead
    for ax in axs.flatten():
        ax.get_legend().remove()

    if kwargs.get('strategy_type', 'SDC') == 'SDC':
        axs[1, 1].legend(frameon=False, loc="lower right")
    else:
        axs[0, 1].legend(frameon=False, loc="lower right")
    axs[0, 0].set_ylim((-0.05, 1.05))
    axs[1, 0].set_ylabel('recovery rate')
    axs[0, 0].set_ylabel('recovery rate')

    # encode the options in the file name, e.g. '_num_procs-1_strategy_type-SDC'
    name = ''.join(f'_{key}-{val}' for key, val in kwargs.items())

    # savefig appends the extension itself; passing '.pdf' here used to produce '*.pdf.pdf'
    savefig(fig, f'compare_equations{name}')

229 

230 

def plot_adaptivity_stuff():  # pragma: no cover
    """
    Plot the solution for a van der Pol problem as well as the local error and cost associated with the base scheme and
    adaptivity in k and dt in order to demonstrate that adaptivity is useful.

    The top panel shows the first solution component from the base strategy run. The middle panel shows the local
    error and the bottom panel the cumulative number of Newton iterations, one line per strategy.

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep
    from pySDC.implementations.hooks.log_work import LogWork
    from pySDC.projects.Resilience.hook import LogData
    from pySDC.implementations.convergence_controller_classes.adaptivity import (
        AdaptivityPolynomialError as AdaptivityPolynomialErrorController,
    )

    stats_analyser = get_stats(run_vdp, 'data/stats')

    my_setup_mpl()
    scale = 0.5 if JOURNAL == 'JSC_beamer' else 1.0
    fig, axs = plt.subplots(3, 1, figsize=figsize_by_journal(JOURNAL, scale, 1), sharex=True, sharey=False)

    def plot_error(stats, ax, iter_ax, strategy, **kwargs):
        """
        Plot the local error and the cumulative sum of Newton iterations over time.

        Args:
            stats (dict): Stats from pySDC run
            ax (Matplotlib.pyplot.axes): Somewhere to plot the local error
            iter_ax (Matplotlib.pyplot.axes): Somewhere to plot the cumulative Newton iterations
            strategy (pySDC.projects.Resilience.strategies.Strategy): The resilience strategy, supplies the line style
            **kwargs: Additional arguments passed to ``plot``

        Returns:
            None
        """
        markevery = 40
        e = get_sorted(stats, type='e_local_post_step', recomputed=False)
        ax.plot([me[0] for me in e], [me[1] for me in e], markevery=markevery, **strategy.style, **kwargs)
        k = get_sorted(stats, type='work_newton')
        iter_ax.plot(
            [me[0] for me in k], np.cumsum([me[1] for me in k]), **strategy.style, markevery=markevery, **kwargs
        )
        ax.set_yscale('log')
        ax.set_ylabel('local error')
        iter_ax.set_ylabel(r'Newton iterations')

    for strategy in [BaseStrategy, AdaptivityStrategy, IterateStrategy, AdaptivityPolynomialError]:
        if strategy == AdaptivityPolynomialError:
            # this strategy overrides the number of nodes and its adaptivity parameters
            force_params = {
                'sweeper_params': {'num_nodes': 2},
                'convergence_controllers': {
                    AdaptivityPolynomialErrorController: {
                        'e_tol': 7e-5,
                        'restol_rel': 1e-4,
                        'restol_min': 1e-10,
                        'restart_at_maxiter': True,
                        'factor_if_not_converged': 4.0,
                    },
                },
            }
        else:
            force_params = {}
        stats, _, _ = stats_analyser.single_run(
            strategy=strategy(useMPI=False),
            force_params=force_params,
            hook_class=[LogLocalErrorPostStep, LogData, LogWork],
        )
        plot_error(stats, axs[1], axs[2], strategy())

        if strategy == BaseStrategy:
            # the solution itself is only drawn once, from the base strategy run
            u = get_sorted(stats, type='u', recomputed=False)
            axs[0].plot([me[0] for me in u], [me[1][0] for me in u], color='black', label=r'$u$')

    axs[2].set_xlabel(r'$t$')
    axs[0].set_ylabel('solution')
    axs[2].legend(frameon=JOURNAL == 'JSC_beamer')
    axs[1].legend(frameon=True)
    savefig(fig, 'adaptivity')

308 

309 

def plot_fault_vdp(bit=0):  # pragma: no cover
    """
    Make a plot showing the impact of a fault on van der Pol without any resilience.
    The faults are inserted in the last iteration in the last node in u_t such that you can best see the impact.

    Both solution components are drawn for the run without faults (dashed, superscript *) and with faults (solid).
    Each fault is marked by a dotted vertical line.

    Args:
        bit (int): The bit that you want to flip

    Returns:
        None
    """
    from pySDC.projects.Resilience.fault_stats import (
        FaultStats,
        BaseStrategy,
    )
    from pySDC.projects.Resilience.hook import LogData

    stats_analyser = FaultStats(
        prob=run_vdp,
        strategies=[BaseStrategy()],
        faults=[False, True],
        reload=True,
        recovery_thresh=1.1,
        num_procs=1,
        mode='combination',
    )

    my_setup_mpl()
    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.8, 0.5))
    colors = ['blue', 'red', 'magenta']
    ls = ['--', '-']
    markers = ['*', '^']
    do_faults = [False, True]
    superscripts = ['*', '']
    subscripts = ['', 't', '']

    # NOTE(review): the run index presumably enumerates fault configurations with 12 runs per bit -- confirm
    run = 779 + 12 * bit  # for faults in u_t
    # run = 11 + 12 * bit  # for faults in u

    fault_labelled = False
    for i, faults_on in enumerate(do_faults):
        stats, _, _ = stats_analyser.single_run(
            strategy=BaseStrategy(),
            run=run,
            faults=faults_on,
            hook_class=[LogData],
        )
        u = get_sorted(stats, type='u')
        faults = get_sorted(stats, type='bitflip')
        for j in [0, 1]:
            ax.plot(
                [me[0] for me in u],
                [me[1][j] for me in u],
                ls=ls[i],
                color=colors[j],
                # doubled braces are literal braces in f-strings, e.g. '$u^{*}_{t}$'
                label=rf'$u^{{{superscripts[i]}}}_{{{subscripts[j]}}}$',
                marker=markers[j],
                markevery=60,
            )
        for fault in faults:
            # label only the first fault line so the legend contains a single 'Fault' entry
            ax.axvline(fault[0], color='black', label='_nolegend_' if fault_labelled else 'Fault', ls=':')
            fault_labelled = True
            print(
                f'Fault at t={fault[0]:.2e}, iter={fault[1][1]}, node={fault[1][2]}, space={fault[1][3]}, bit={fault[1][4]}'
            )
            ax.set_title(f'Fault in bit {fault[1][4]}')

    ax.legend(frameon=True, loc='lower left')
    ax.set_xlabel(r'$t$')
    savefig(fig, f'fault_bit_{bit}')

378 

379 

def plot_quench_solution():  # pragma: no cover
    """
    Plot the maximum of the Quench solution over time together with the problem's threshold and maximum values.

    Returns:
        None
    """
    my_setup_mpl()
    if JOURNAL == 'JSC_beamer':
        size = figsize_by_journal(JOURNAL, 0.5, 0.9)
    else:
        size = figsize_by_journal(JOURNAL, 1.0, 0.45)
    fig, ax = plt.subplots(figsize=size)

    base = BaseStrategy()
    description = base.get_custom_description(run_quench, num_procs=1)
    stats, controller, _ = run_quench(custom_description=description, Tend=base.get_Tend(run_quench))
    problem = controller.MS[0].levels[0].prob

    solution = get_sorted(stats, type='u', recomputed=False)
    times = [entry[0] for entry in solution]
    peaks = [max(entry[1]) for entry in solution]

    ax.plot(times, peaks, color='black', label='$T$')
    reference_lines = (
        (problem.u_thresh, r'$T_\mathrm{thresh}$', '--'),
        (problem.u_max, r'$T_\mathrm{max}$', ':'),
    )
    for level, label, style in reference_lines:
        # zorder=-1 keeps the reference lines behind the solution
        ax.axhline(level, label=label, ls=style, color='grey', zorder=-1)

    ax.set_xlabel(r'$t$')
    ax.legend(frameon=False)
    savefig(fig, 'quench_sol')

410 

411 

def plot_AC_solution():  # pragma: no cover
    """
    Plot the initial solution of the Allen-Cahn problem and the radius over time.

    The left panel shows the first recorded solution with a colorbar. The right panel compares the computed and exact
    radius recorded by the TOMS ``monitor`` hook up to t=0.032.

    Returns:
        None

    Raises:
        NotImplementedError: For the 'JSC_beamer' format, which has no layout for this plot
    """
    from pySDC.projects.TOMS.AllenCahn_monitor import monitor

    my_setup_mpl()
    if JOURNAL == 'JSC_beamer':
        # the code that used to follow this raise was unreachable (and created ``ax`` instead of ``axs``)
        raise NotImplementedError('plot_AC_solution does not support the JSC_beamer format')
    fig, axs = plt.subplots(1, 2, figsize=figsize_by_journal(JOURNAL, 1.0, 0.45))

    stats, _, _ = run_AC(Tend=0.032, hook_class=monitor)

    u = get_sorted(stats, type='u')

    computed_radius = get_sorted(stats, type='computed_radius')
    exact_radius = get_sorted(stats, type='exact_radius')
    axs[1].plot([me[0] for me in computed_radius], [me[1] for me in computed_radius], ls='-', label='numerical')
    axs[1].plot([me[0] for me in exact_radius], [me[1] for me in exact_radius], ls='--', color='black', label='exact')
    axs[1].axvline(0.025, ls=':', label=r'$t=0.025$', color='grey')
    axs[1].set_title('Radius over time')
    axs[1].set_xlabel('$t$')
    axs[1].legend(frameon=False)

    im = axs[0].imshow(u[0][1], extent=(-0.5, 0.5, -0.5, 0.5))
    # attach the colorbar explicitly to the panel showing the image
    fig.colorbar(im, ax=axs[0])
    axs[0].set_title(r'$u_0$')
    axs[0].set_xlabel('$x$')
    axs[0].set_ylabel('$y$')
    savefig(fig, 'AC_sol')

441 

442 

def plot_vdp_solution():  # pragma: no cover
    """
    Plot the first component of the van der Pol solution up to t=28.6 to illustrate the varying time scales.
    The solution is computed with adaptivity at e_tol=1e-7.

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    my_setup_mpl()
    width, height = (0.5, 0.9) if JOURNAL == 'JSC_beamer' else (1.0, 0.33)
    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, width, height))

    description = {'convergence_controllers': {Adaptivity: {'e_tol': 1e-7}}}
    stats, _, _ = run_vdp(custom_description=description, Tend=28.6)

    solution = get_sorted(stats, type='u')
    times = [entry[0] for entry in solution]
    first_component = [entry[1][0] for entry in solution]
    ax.plot(times, first_component, color='black')
    ax.set_ylabel(r'$u$')
    ax.set_xlabel(r'$t$')
    savefig(fig, 'vdp_sol')

467 

468 

def work_precision():  # pragma: no cover
    """
    Make the work-precision plots for all problems in the modes 'compare_strategies', 'parallel_efficiency' and
    'RK_comp'.

    Work is measured by key 't' and precision by 'e_global_rel'. The call uses ``record=False``, which presumably
    reuses previously generated data. Plots are written to the current ``BASE_PATH``.

    Returns:
        None
    """
    from pySDC.projects.Resilience.work_precision import (
        all_problems,
    )

    all_params = {
        'record': False,
        'work_key': 't',
        'precision_key': 'e_global_rel',
        'plotting': True,
        # follow the output location used by all other plots instead of hard-coding 'data/paper'
        'base_path': BASE_PATH,
    }

    for mode in ['compare_strategies', 'parallel_efficiency', 'RK_comp']:
        all_problems(**all_params, mode=mode)

484 

485 

def make_plots_for_TIME_X_website():  # pragma: no cover
    """
    Make PNG plots for the TIME-X website in beamer format.

    Produces the recovery rate of recoverable faults for van der Pol and the van der Pol stiffness plot.
    """
    global JOURNAL, BASE_PATH
    JOURNAL, BASE_PATH = 'JSC_beamer', 'data/paper/time-x_website'

    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.5, 2.0 / 3.0))
    vdp_stats = get_stats(run_vdp)
    plot_recovery_rate_recoverable_only(vdp_stats, fig, ax)
    savefig(fig, 'recovery_rate', format='png')

    from pySDC.projects.Resilience.work_precision import vdp_stiffness_plot

    vdp_stiffness_plot(base_path=BASE_PATH, format='png')

498 

499 

def make_plots_for_SIAM_CSE23():  # pragma: no cover
    """
    Make plots for the SIAM CSE23 talk in beamer format, saved to 'data/paper/SIAMCSE23'.
    """
    global JOURNAL, BASE_PATH
    JOURNAL, BASE_PATH = 'JSC_beamer', 'data/paper/SIAMCSE23'

    fig, ax = plt.subplots(figsize=figsize_by_journal(JOURNAL, 0.5, 3.0 / 4.0))
    vdp_stats = get_stats(run_vdp)
    plot_recovery_rate_recoverable_only(vdp_stats, fig, ax)
    savefig(fig, 'recovery_rate')

    for make_plot in (plot_adaptivity_stuff, compare_recovery_rate_problems, plot_vdp_solution):
        make_plot()

515 

516 

def make_plots_for_paper():  # pragma: no cover
    """
    Make all plots for the paper in Springer Numerical Algorithms format, saved to 'data/paper'.
    """
    global JOURNAL, BASE_PATH
    JOURNAL, BASE_PATH = 'Springer_Numerical_Algorithms', 'data/paper'

    # adaptivity and cost
    plot_adaptivity_stuff()
    work_precision()

    # solutions of the test problems
    for plot_solution in (plot_vdp_solution, plot_AC_solution, plot_quench_solution):
        plot_solution()

    # resilience
    plot_recovery_rate(get_stats(run_vdp))
    for bit in (0, 13):
        plot_fault_vdp(bit)
    compare_recovery_rate_problems(num_procs=1, strategy_type='SDC')

537 

538 

def make_plots_for_notes():  # pragma: no cover
    """
    Make PNG plots for the notes for the website / GitHub.

    Each problem gets its own output directory. ``analyse_resilience`` always saves under the same file names
    ('compare_strategies', 'recovery_rate_compared'), so sharing one directory let the Quench plots overwrite the
    Lorenz ones.
    """
    import os

    global JOURNAL, BASE_PATH
    JOURNAL = 'Springer_Numerical_Algorithms'

    for problem, folder in [(run_Lorenz, 'notes/Lorenz'), (run_quench, 'notes/Quench')]:
        BASE_PATH = folder
        # savefig does not create directories, and the per-problem folder may not exist yet
        os.makedirs(BASE_PATH, exist_ok=True)
        analyse_resilience(problem, format='png')

549 

550 

if __name__ == "__main__":
    # Select which set of plots to generate; the alternatives are
    #   make_plots_for_notes()
    #   make_plots_for_SIAM_CSE23()
    #   make_plots_for_TIME_X_website()
    make_plots_for_paper()