Coverage for pySDC/projects/Resilience/quench.py: 49%

217 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1# script to run a quench problem 

2from pySDC.implementations.problem_classes.Quench import Quench, QuenchIMEX 

3from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

4from pySDC.core.hooks import Hooks 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.projects.Resilience.hook import hook_collection, LogData 

7from pySDC.projects.Resilience.strategies import merge_descriptions 

8from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient 

9import numpy as np 

10 

11import matplotlib.pyplot as plt 

12from pySDC.core.errors import ConvergenceError 

13 

14 

class live_plot(Hooks):  # pragma: no cover
    """
    Hook that redraws the solution and the non-linear part of the right hand side after every step.

    Restarts (e.g. from adaptivity) are not marked in the figure, so the temperature profile may jump back in time
    after a step has been restarted.
    """

    def _plot_state(self, step, level_number):  # pragma: no cover
        """
        Redraw the figure with the solution at the last collocation node and the non-linear part of the right hand
        side evaluated there.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        lvl = step.levels[level_number]
        for axis in self.axs:
            axis.cla()

        u_last = lvl.u[-1]
        sol_ax = self.axs[0]
        rhs_ax = self.axs[1]

        # left panel: solution and the temperature threshold
        sol_ax.plot(lvl.prob.xv, u_last)
        sol_ax.axhline(lvl.prob.u_thresh, color='black')
        # right panel: non-linear part of the right hand side
        rhs_ax.plot(lvl.prob.xv, lvl.prob.eval_f_non_linear(u_last, lvl.time))
        sol_ax.set_ylim(0, 0.025)

        self.fig.suptitle(f"t={lvl.time:.2e}, k={step.status.iter}")
        plt.pause(1e-1)

    def pre_run(self, step, level_number):  # pragma: no cover
        """
        Create the figure with two side-by-side axes that the state is drawn into.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        self.fig, self.axs = plt.subplots(1, 2, figsize=(10, 4))

    def post_step(self, step, level_number):  # pragma: no cover
        """
        Redraw the figure once the step is finished.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        self._plot_state(step, level_number)

67 

68 

def run_quench(
    custom_description=None,
    num_procs=1,
    Tend=6e2,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    imex=False,
    u0=None,
    t0=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a toy problem of a superconducting magnet with a temperature leak with default parameters.

    Args:
        custom_description (dict): Overwrite presets. An ``imex`` entry in its ``problem_params`` overrides the
            ``imex`` argument. The dictionary passed in is not modified.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list/tuple of pySDC.Hook): Hook(s) to store data, added to the default hooks
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        imex (bool): Solve the problem IMEX or fully implicit
        u0 (dtype_u): Initial value. Defaults to the exact solution at ``t0``
        t0 (float): Starting time. Defaults to 0
        use_MPI (bool): Whether or not to use MPI
        **kwargs: ``comm`` selects the MPI communicator when ``use_MPI`` is set (default ``MPI.COMM_WORLD``)

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    if custom_description is not None:
        custom_problem_params = custom_description.get('problem_params', {})
        if 'imex' in custom_problem_params:
            imex = custom_problem_params['imex']
            # ``imex`` selects the problem and sweeper classes and is not forwarded as a problem parameter.
            # Strip it from a shallow copy rather than popping it from the caller's dictionary, so that re-using the
            # same description for another run keeps the setting.
            custom_description = {
                **custom_description,
                'problem_params': {key: val for key, val in custom_problem_params.items() if key != 'imex'},
            }

    level_params = {}
    level_params['dt'] = 10.0

    sweeper_params = {}
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['QE'] = 'PIC'

    problem_params = {
        'newton_tol': 1e-9,
        'direct_solver': False,
        'order': 6,
        'nvars': 2**7,
    }

    step_params = {}
    step_params['maxiter'] = 5

    # accept a single hook or any list/tuple of hooks
    extra_hooks = list(hook_class) if isinstance(hook_class, (list, tuple)) else [hook_class]

    controller_params = {}
    controller_params['logger_level'] = 30
    controller_params['hook_class'] = hook_collection + extra_hooks
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    description = {}
    description['problem_class'] = QuenchIMEX if imex else Quench
    description['problem_params'] = problem_params
    description['sweeper_class'] = imex_1st_order_efficient if imex else generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    t0 = 0.0 if t0 is None else t0

    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }

    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0) if u0 is None else u0

    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    crash = False
    try:
        uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except ConvergenceError as e:
        # keep whatever was recorded up to the failure and report the crash to the caller
        print(f'Warning: Premature termination! {e}')
        stats = controller.return_stats()
        crash = True
    return stats, controller, crash

179 

180 

def faults(seed=0):  # pragma: no cover
    """
    Plot the mean temperature over time for three runs into one figure:
    a reference run without faults, a run with faults at the preset step size, and a run with faults and adaptivity.

    Args:
        seed (int): Seed of the random number generator used for drawing the faults

    Returns:
        None
    """
    import matplotlib.pyplot as plt
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    fig, ax = plt.subplots(1, 1)

    # the same generator is shared by both faulty runs
    fault_stuff = {'rng': np.random.RandomState(seed), 'args': {}, 'rnd_args': {}}
    controller_params = {'logger_level': 30}
    description = {'level_params': {'dt': 1e1}, 'step_params': {'maxiter': 5}}

    # reference without faults
    stats, controller, _ = run_quench(custom_controller_params=controller_params, custom_description=description)
    plot_solution_faults(stats, controller, ax, plot_lines=True, label='ref')

    # faults, no custom description
    stats, controller, _ = run_quench(fault_stuff=fault_stuff, custom_controller_params=controller_params)
    plot_solution_faults(stats, controller, ax, label='fixed')

    # faults with adaptive step size selection
    description['convergence_controllers'] = {Adaptivity: {'e_tol': 1e-7, 'dt_max': 1e2, 'dt_min': 1e-3}}
    stats, controller, _ = run_quench(
        fault_stuff=fault_stuff,
        custom_controller_params=controller_params,
        custom_description=description,
    )
    plot_solution_faults(stats, controller, ax, label='adaptivity', ls='--')

    plt.show()

208 

209 

def plot_solution_faults(stats, controller, ax, plot_lines=False, **kwargs):  # pragma: no cover
    """
    Plot the mean of the logged solution over time into an existing axis and mark the times of injected faults.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller that produced ``stats`` (MPI or non-MPI)
        ax: Matplotlib axis to plot into
        plot_lines (bool): Also draw horizontal lines at the threshold and maximum temperature of the problem
        **kwargs: Passed on to ``ax.plot`` for the temperature curve

    Returns:
        None
    """
    u = get_sorted(stats, type='u', recomputed=False)
    ax.plot([me[0] for me in u], [np.mean(me[1]) for me in u], **kwargs)

    if plot_lines:
        # MPI controllers expose their step as ``S``, non-MPI controllers keep a list of steps in ``MS``
        if getattr(controller, 'useMPI', False):
            P = controller.S.levels[0].prob
        else:
            P = controller.MS[0].levels[0].prob
        ax.axhline(P.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
        ax.axhline(P.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for fault in get_sorted(stats, type='bitflip'):
        ax.axvline(fault[0], color='grey', label=f'fault at t={fault[0]:.2f}')

    ax.legend()
    ax.set_xlabel(r'$t$')
    ax.set_ylabel(r'$T$')

226 

227 

def get_crossing_time(stats, controller, num_points=5, inter_points=50, temperature_error_thresh=1e-5):
    """
    Compute the time when the temperature threshold is crossed based on interpolation.

    The mean of each logged solution vector is interpolated with a Lagrange polynomial through ``num_points`` logged
    times around the first entry exceeding ``u_thresh``. The interpolant is evaluated on ``inter_points`` equidistant
    times spanning those nodes. The first evaluation time where it exceeds the threshold is returned. If the
    interpolated temperature there is further than ``temperature_error_thresh`` from the threshold, the function
    calls itself with 4 more nodes and 15 more evaluation points. Refinement stops once ``inter_points`` is 300 or more.

    Only the non-MPI controller is supported, since the problem is taken from ``controller.MS[0]``.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller
        num_points (int): The number of points in the solution you want to use for interpolation
        inter_points (int): The resolution of the interpolation
        temperature_error_thresh (float): The temperature error compared to the actual threshold you want to allow

    Returns:
        float: The time when the temperature threshold is crossed

    Raises:
        IndexError: If the logged mean temperature (or its interpolant) never exceeds the threshold, or if already the
            first logged value exceeds it (the stencil then shrinks to zero points).
        AssertionError: See the NOTE on the assertion below.
    """
    from qmat.lagrange import LagrangeApproximation

    P = controller.MS[0].levels[0].prob
    u_thresh = P.u_thresh

    u = get_sorted(stats, type='u', recomputed=False)
    temp = np.array([np.mean(me[1]) for me in u])
    t = np.array([me[0] for me in u])

    # first logged entry above the threshold
    crossing_index = np.arange(len(temp))[temp > u_thresh][0]

    # interpolation stuff
    # shrink the stencil so that it stays inside the logged data around crossing_index
    num_points = min([num_points, crossing_index * 2, len(temp) - crossing_index])
    # num_points consecutive indices, roughly centred on crossing_index
    idx = np.arange(num_points) - num_points // 2 + crossing_index
    t_grid = t[idx]
    u_grid = temp[idx]
    t_inter = np.linspace(t_grid[0], t_grid[-1], inter_points)
    interpolator = LagrangeApproximation(points=t_grid)
    u_inter = interpolator.getInterpolationMatrix(t_inter) @ u_grid

    # first evaluation point of the interpolant above the threshold
    crossing_inter = np.arange(len(u_inter))[u_inter > u_thresh][0]

    temperature_error = abs(u_inter[crossing_inter] - u_thresh)

    # NOTE(review): this compares the interpolation error with the temperature value itself, not with the error of the
    # raw crossing, abs(temp[crossing_index] - u_thresh), which the message suggests. As written the check is much
    # weaker than intended — confirm before tightening, as the stricter check may fail on valid runs.
    assert temperature_error < temp[crossing_index], "Temperature error is rising due to interpolation!"

    # refine with more nodes and a finer evaluation grid until accurate enough or the resolution cap is hit
    if temperature_error > temperature_error_thresh and inter_points < 300:
        return get_crossing_time(stats, controller, num_points + 4, inter_points + 15, temperature_error_thresh)

    return t_inter[crossing_inter]

272 

273 

def plot_solution(stats, controller):  # pragma: no cover
    """
    Plot the mean of the logged solution and the step size over time in a new figure.

    Threshold and maximum temperature of the problem are drawn as horizontal lines, injected faults as vertical lines.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller that produced ``stats`` (MPI or non-MPI)

    Returns:
        None
    """
    import matplotlib.pyplot as plt

    fig, u_ax = plt.subplots(1, 1)
    dt_ax = u_ax.twinx()

    solution = get_sorted(stats, type='u', recomputed=False)
    times = [entry[0] for entry in solution]
    mean_temperature = [np.mean(entry[1]) for entry in solution]
    u_ax.plot(times, mean_temperature, label=r'$T$')

    step_sizes = get_sorted(stats, type='dt', recomputed=False)
    dt_ax.plot([entry[0] for entry in step_sizes], [entry[1] for entry in step_sizes], color='black', ls='--')
    # empty line on the temperature axis so the step size gets an entry in its legend
    u_ax.plot([None], [None], color='black', ls='--', label=r'$\Delta t$')

    step = controller.S if controller.useMPI else controller.MS[0]
    prob = step.levels[0].prob
    u_ax.axhline(prob.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
    u_ax.axhline(prob.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for fault in get_sorted(stats, type='bitflip'):
        u_ax.axvline(fault[0], color='grey', label=f'fault at t={fault[0]:.2f}')

    u_ax.legend()
    u_ax.set_xlabel(r'$t$')
    u_ax.set_ylabel(r'$T$')
    dt_ax.set_ylabel(r'$\Delta t$')

301 

302 

def compare_imex_full(plotting=False, leak_type='linear'):
    """
    Compare the results of IMEX and fully implicit runs.

    Both variants are run with adaptivity up to t=500. The function asserts that:
    - only the fully implicit variant does Newton iterations, and its largest logged Newton work divided by the
      number of nodes and iterations does not exceed ``newton_maxiter``;
    - the final solutions differ by less than 4e-3;
    - the maximum final temperature of the IMEX run exceeds ``u_max`` (runaway);
    - the floored mean of the logged right hand side work is the same for both;
    - the final global errors are below 1.2e-4 (IMEX) and 8e-5 (fully implicit).

    Args:
        plotting (bool): Plot the solution or not
        leak_type (str): Passed on to the problem as ``leak_type``
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
    from pySDC.implementations.hooks.log_work import LogWork
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    maxiter = 5
    num_nodes = 3
    newton_maxiter = 99

    custom_description = {
        'problem_params': {
            'newton_tol': 1e-10,
            'newton_maxiter': newton_maxiter,
            'nvars': 2**7,
            'leak_type': leak_type,
        },
        'step_params': {'maxiter': maxiter},
        'sweeper_params': {'num_nodes': num_nodes},
        'convergence_controllers': {Adaptivity: {'e_tol': 1e-7}},
    }
    custom_controller_params = {'logger_level': 15}

    res, rhs, error = {}, {}, {}

    for imex in [False, True]:
        stats, controller, _ = run_quench(
            custom_description=custom_description,
            custom_controller_params=custom_controller_params,
            imex=imex,
            Tend=5e2,
            use_MPI=False,
            hook_class=[LogWork, LogGlobalErrorPostRun],
        )

        res[imex] = get_sorted(stats, type='u')[-1][1]
        newton_iter = [entry[1] for entry in get_sorted(stats, type='work_newton')]
        rhs[imex] = np.mean([entry[1] for entry in get_sorted(stats, type='work_rhs')]) // 1
        error[imex] = get_sorted(stats, type='e_global_post_run')[-1][1]

        if not imex:
            newton_per_sweep = max(newton_iter) / num_nodes / maxiter
            assert newton_per_sweep <= newton_maxiter, "Took more Newton iterations than allowed!"
        else:
            assert all(entry == 0 for entry in newton_iter), "IMEX is not supposed to do Newton iterations!"

        if plotting:  # pragma: no cover
            plot_solution(stats, controller)

    diff = abs(res[True] - res[False])
    thresh = 4e-3
    assert diff < thresh, (
        f"Difference between IMEX and fully-implicit too large! Got {diff:.2e}, allowed is only {thresh:.2e}!"
    )

    prob = controller.MS[0].levels[0].prob
    assert max(res[True]) > prob.u_max, (
        f"Expected runaway to happen, but maximum temperature is {max(res[True]):.2e} < u_max={prob.u_max:.2e}!"
    )

    assert rhs[True] == rhs[False], (
        f"Expected IMEX and fully implicit schemes to take the same number of right hand side evaluations per step, but got {rhs[True]} and {rhs[False]}!"
    )

    assert error[True] < 1.2e-4, f'Expected error of IMEX version to be less than 1.2e-4, but got e={error[True]:.2e}!'
    assert error[False] < 8e-5, (
        f'Expected error of fully implicit version to be less than 8e-5, but got e={error[False]:.2e}!'
    )

376 

377 

def compare_reference_solutions_single(reference_sol_types=None):
    """
    Run the fully implicit quench problem once with ``DoubleAdaptivityStrategy`` for each reference solution type and
    plot the maximum temperature together with the global and local errors over time.

    The figure is stored as ``data/quench_refs_single.pdf``; the ``data`` directory is created if needed.

    Args:
        reference_sol_types (list): Values for the problem parameter ``reference_sol_type`` to compare. Defaults to
            ``['scipy']``. ``'DIRK'`` and ``'SDC'`` are the other types used in this file.

    Returns:
        None
    """
    import os
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostStep, LogLocalErrorPostStep
    from pySDC.implementations.hooks.log_solution import LogSolution
    from pySDC.projects.Resilience.strategies import DoubleAdaptivityStrategy
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    types = ['scipy'] if reference_sol_types is None else list(reference_sol_types)
    if not types:
        raise ValueError('Need at least one reference solution type to compare!')

    fig, ax = plt.subplots()
    error_ax = ax.twinx()
    Tend = 500

    colors = ['black', 'teal', 'magenta']

    strategy = DoubleAdaptivityStrategy()

    controller_params = {'logger_level': 15}

    for j, ref_type in enumerate(types):
        # cycle through the colors so that more types than colors do not fail
        color = colors[j % len(colors)]

        description = {}
        description['level_params'] = {'dt': 5.0, 'restol': -1}
        description['sweeper_params'] = {'QI': 'IE', 'num_nodes': 3}
        description['problem_params'] = {
            'leak_type': 'linear',
            'leak_transition': 'step',
            'nvars': 2**10,
            'reference_sol_type': ref_type,
            'newton_tol': 1e-12,
        }

        description = merge_descriptions(description, strategy.get_custom_description(run_quench, 1))
        description['step_params'] = {'maxiter': 5}
        description['convergence_controllers'][Adaptivity]['e_tol'] = 1e-7

        stats, controller, _ = run_quench(
            custom_description=description,
            hook_class=[LogGlobalErrorPostStep, LogLocalErrorPostStep, LogSolution],
            Tend=Tend,
            imex=False,
            custom_controller_params=controller_params,
        )
        e_glob = get_sorted(stats, type='e_global_post_step', recomputed=False)
        e_loc = get_sorted(stats, type='e_local_post_step', recomputed=False)
        u = get_sorted(stats, type='u', recomputed=False)

        ax.plot([me[0] for me in u], [max(me[1]) for me in u], color=color, label=f'{ref_type} reference')

        error_ax.plot([me[0] for me in e_glob], [me[1] for me in e_glob], color=color, ls='--')
        error_ax.plot([me[0] for me in e_loc], [me[1] for me in e_loc], color=color, ls=':')

    prob = controller.MS[0].levels[0].prob
    ax.axhline(prob.u_thresh, ls='-.', color='grey')
    ax.axhline(prob.u_max, ls='-.', color='grey')
    # empty lines only to get legend entries for the line styles of the errors
    ax.plot([None], [None], ls='--', label=r'$e_\mathrm{global}$', color='grey')
    ax.plot([None], [None], ls=':', label=r'$e_\mathrm{local}$', color='grey')
    error_ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$t$')
    ax.set_ylabel('solution')
    error_ax.set_ylabel('error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    os.makedirs('data', exist_ok=True)
    fig.savefig('data/quench_refs_single.pdf', bbox_inches='tight')

443 

444 

def compare_reference_solutions():
    """
    For each reference solution type, run the fully implicit quench problem with a range of step sizes and plot the
    maximum local error against the step size.

    The figure is stored as ``data/quench_refs.pdf``; the ``data`` directory is created if needed.

    Returns:
        None
    """
    import os
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun, LogLocalErrorPostStep

    types = ['DIRK', 'SDC', 'scipy']
    fig, ax = plt.subplots()
    Tend = 500
    dt_list = [Tend / 2.0**me for me in [2, 3, 4, 5, 6, 7, 8, 9, 10]]

    for ref_type in types:
        errors = [None] * len(dt_list)
        for i, dt in enumerate(dt_list):
            description = {}
            description['level_params'] = {'dt': dt, 'restol': 1e-10}
            description['sweeper_params'] = {'QI': 'IE', 'num_nodes': 3}
            description['problem_params'] = {
                'leak_type': 'linear',
                'leak_transition': 'step',
                'nvars': 2**10,
                'reference_sol_type': ref_type,
            }

            stats, controller, _ = run_quench(
                custom_description=description,
                hook_class=[LogGlobalErrorPostRun, LogLocalErrorPostStep],
                Tend=Tend,
                imex=False,
            )
            # LogGlobalErrorPostRun also records 'e_global_post_run', which could be plotted instead
            errors[i] = max([me[1] for me in get_sorted(stats, type='e_local_post_step', recomputed=False)])
            print(errors)
        ax.loglog(dt_list, errors, label=f'{ref_type} reference')

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    # the plotted quantity is the maximum local error over all steps, not the global error
    ax.set_ylabel('max. local error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    os.makedirs('data', exist_ok=True)
    fig.savefig('data/quench_refs.pdf', bbox_inches='tight')

483 

484 

def check_order(reference_sol_type='scipy'):
    """
    Check the order of accuracy of the fully implicit quench problem for different numbers of SDC iterations.

    For every number of iterations, the problem is solved with a range of step sizes. The maximum of the embedded
    error estimate is plotted against the step size, together with a dashed reference line whose slope equals the
    number of iterations. The observed orders between consecutive step sizes and their mean are printed. The figure
    is stored as ``data/order_quench_{reference_sol_type}.pdf``; the ``data`` directory is created if needed.

    Args:
        reference_sol_type (str): Passed on to the problem as ``reference_sol_type``

    Returns:
        None
    """
    import os
    from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError

    Tend = 500
    maxiter_list = [1, 2, 3, 4, 5]
    dt_list = [Tend / 2.0**me for me in [4, 5, 6, 7, 8]]

    fig, ax = plt.subplots()

    controller_params = {'logger_level': 30}

    colors = ['black', 'teal', 'magenta', 'orange', 'red']
    for j, maxiter in enumerate(maxiter_list):
        errors = [None] * len(dt_list)

        for i, dt in enumerate(dt_list):
            description = {}
            description['level_params'] = {'dt': dt}
            description['step_params'] = {'maxiter': maxiter}
            description['sweeper_params'] = {'QI': 'IE', 'num_nodes': 3}
            description['problem_params'] = {
                'leak_type': 'linear',
                'leak_transition': 'step',
                'nvars': 2**7,
                'reference_sol_type': reference_sol_type,
                'newton_tol': 1e-10,
                'lintol': 1e-10,
                'direct_solver': True,
            }
            description['convergence_controllers'] = {EstimateEmbeddedError: {}}

            stats, controller, _ = run_quench(
                custom_description=description,
                hook_class=[],
                Tend=Tend,
                imex=False,
                custom_controller_params=controller_params,
            )
            errors[i] = max([me[1] for me in get_sorted(stats, type='error_embedded_estimate')])
            print(errors)

        ax.loglog(dt_list, errors, color=colors[j], label=f'{maxiter} iterations')
        ax.loglog(dt_list, [errors[0] * (me / dt_list[0]) ** maxiter for me in dt_list], color=colors[j], ls='--')

        # observed order between consecutive step sizes
        dts = np.array(dt_list)
        errs = np.array(errors)
        orders = np.log(errs[1:] / errs[:-1]) / np.log(dts[1:] / dts[:-1])
        print(orders, np.mean(orders))

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    # the plotted quantity is the maximum embedded error estimate over all steps, not the global error
    ax.set_ylabel('max. embedded error estimate')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    os.makedirs('data', exist_ok=True)
    fig.savefig(f'data/order_quench_{reference_sol_type}.pdf', bbox_inches='tight')

554 

555 

if __name__ == '__main__':
    # Other studies that can be run from here instead:
    #   compare_reference_solutions_single()
    #   faults(19)
    #   compare_imex_full(plotting=True)
    for reference_sol_type in ('scipy',):
        check_order(reference_sol_type=reference_sol_type)
    plt.show()