Coverage for pySDC/projects/Resilience/quench.py: 49%

218 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1# script to run a quench problem 

2from pySDC.implementations.problem_classes.Quench import Quench, QuenchIMEX 

3from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

4from pySDC.core.Hooks import hooks 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.projects.Resilience.hook import hook_collection, LogData 

7from pySDC.projects.Resilience.strategies import merge_descriptions 

8from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient 

9import numpy as np 

10 

11import matplotlib.pyplot as plt 

12from pySDC.core.Errors import ConvergenceError 

13 

14 

class live_plot(hooks):  # pragma: no cover
    """
    This hook plots the solution and the non-linear part of the right hand side after every step. Keep in mind that
    using adaptivity will result in restarts, which are not marked in these plots. Prepare to see the temperature
    profile jump back again after a restart.
    """

    def _plot_state(self, step, level_number):  # pragma: no cover
        """
        Plot the solution at the last collocation node, the threshold temperature and the non-linear part of the
        right hand side evaluated at that node.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        L = step.levels[level_number]
        for ax in self.axs:
            ax.cla()

        self.axs[0].plot(L.prob.xv, L.u[-1])
        self.axs[0].axhline(L.prob.u_thresh, color='black')
        self.axs[1].plot(L.prob.xv, L.prob.eval_f_non_linear(L.u[-1], L.time))
        self.axs[0].set_ylim(0, 0.025)
        self.fig.suptitle(f"t={L.time:.2e}, k={step.status.iter}")

        # short pause lets the GUI event loop redraw the figure between steps
        plt.pause(1e-1)

    def pre_run(self, step, level_number):  # pragma: no cover
        """
        Setup a figure to plot into

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        super().pre_run(step, level_number)
        self.fig, self.axs = plt.subplots(1, 2, figsize=(10, 4))

    def post_step(self, step, level_number):  # pragma: no cover
        """
        Call the plotting function after the step

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        super().post_step(step, level_number)
        self._plot_state(step, level_number)

67 

68 

def run_quench(
    custom_description=None,
    num_procs=1,
    Tend=6e2,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    imex=False,
    u0=None,
    t0=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a toy problem of a superconducting magnet with a temperature leak with default parameters.

    Args:
        custom_description (dict): Overwrite presets. An ``'imex'`` entry in ``custom_description['problem_params']``
            takes precedence over the ``imex`` argument and is not forwarded to the problem class. The dictionary
            passed in is left unmodified.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): Hook(s) to store data, used in addition to ``hook_collection``
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        imex (bool): Solve the problem IMEX or fully implicit
        u0 (dtype_u): Initial value; defaults to the problem's ``u_exact(t0)``
        t0 (float): Starting time; defaults to 0
        use_MPI (bool): Whether or not to use MPI
        **kwargs: ``comm`` selects the MPI communicator when ``use_MPI`` is set (default ``MPI.COMM_WORLD``)

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    if custom_description is not None:
        custom_problem_params = custom_description.get('problem_params', {})
        if 'imex' in custom_problem_params:
            imex = custom_problem_params['imex']
            # Strip 'imex' from copies instead of popping it in place: the caller's description stays unchanged and
            # reusing it for another run still selects the same scheme.
            custom_description = {
                **custom_description,
                'problem_params': {key: val for key, val in custom_problem_params.items() if key != 'imex'},
            }

    # initialize level parameters
    level_params = {}
    level_params['dt'] = 10.0

    # initialize sweeper parameters
    sweeper_params = {}
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['QE'] = 'PIC'

    problem_params = {
        'newton_tol': 1e-9,
        'direct_solver': False,
        'order': 6,
        'nvars': 2**7,
    }

    # initialize step parameters
    step_params = {}
    step_params['maxiter'] = 5

    # initialize controller parameters
    extra_hooks = list(hook_class) if isinstance(hook_class, (list, tuple)) else [hook_class]
    controller_params = {}
    controller_params['logger_level'] = 30
    controller_params['hook_class'] = hook_collection + extra_hooks
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = {}
    description['problem_class'] = QuenchIMEX if imex else Quench
    description['problem_params'] = problem_params
    description['sweeper_class'] = imex_1st_order_efficient if imex else generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0 if t0 is None else t0

    # instantiate controller
    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }

    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0) if u0 is None else u0

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    # call main function to get things done...
    crash = False
    try:
        uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except ConvergenceError as e:
        # keep whatever was recorded up to the failure so callers can still inspect it
        print(f'Warning: Premature termination! {e}')
        stats = controller.return_stats()
        crash = True
    return stats, controller, crash

188 

189 

def faults(seed=0):  # pragma: no cover
    """
    Plot the mean temperature of a fault-free reference run next to runs with random faults, once with a fixed step
    size and once with adaptivity.

    Args:
        seed (int): Seed for the random number generator handed to the fault injection

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    _fig, ax = plt.subplots(1, 1)

    fault_stuff = {'rng': np.random.RandomState(seed), 'args': {}, 'rnd_args': {}}
    controller_params = {'logger_level': 30}
    description = {'level_params': {'dt': 1e1}, 'step_params': {'maxiter': 5}}

    # reference run without any faults
    ref_stats, ref_controller, _ = run_quench(custom_controller_params=controller_params, custom_description=description)
    plot_solution_faults(ref_stats, ref_controller, ax, plot_lines=True, label='ref')

    # faults with the default fixed step size
    fixed_stats, fixed_controller, _ = run_quench(
        fault_stuff=fault_stuff,
        custom_controller_params=controller_params,
    )
    plot_solution_faults(fixed_stats, fixed_controller, ax, label='fixed')

    # faults with adaptive step size selection
    description['convergence_controllers'] = {Adaptivity: {'e_tol': 1e-7, 'dt_max': 1e2, 'dt_min': 1e-3}}
    adaptive_stats, adaptive_controller, _ = run_quench(
        fault_stuff=fault_stuff, custom_controller_params=controller_params, custom_description=description
    )
    plot_solution_faults(adaptive_stats, adaptive_controller, ax, label='adaptivity', ls='--')

    plt.show()

217 

218 

def plot_solution_faults(stats, controller, ax, plot_lines=False, **kwargs):  # pragma: no cover
    """
    Plot the spatially averaged temperature over time and mark the times at which bit flips were recorded.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller used for the run
        ax: Matplotlib axis to plot into
        plot_lines (bool): Also draw horizontal lines at the threshold and the maximum temperature
        **kwargs: Passed on to ``ax.plot`` for the temperature curve

    Returns:
        None
    """
    u = get_sorted(stats, type='u', recomputed=False)
    ax.plot([me[0] for me in u], [np.mean(me[1]) for me in u], **kwargs)

    if plot_lines:
        # the problem instance lives in different places for the MPI and non-MPI controllers
        P = controller.S.levels[0].prob if controller.useMPI else controller.MS[0].levels[0].prob
        ax.axhline(P.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
        ax.axhline(P.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for me in get_sorted(stats, type='bitflip'):
        ax.axvline(me[0], color='grey', label=f'fault at t={me[0]:.2f}')

    ax.legend()
    ax.set_xlabel(r'$t$')
    ax.set_ylabel(r'$T$')

235 

236 

def get_crossing_time(stats, controller, num_points=5, inter_points=50, temperature_error_thresh=1e-5):
    """
    Compute the time when the temperature threshold is crossed based on interpolation.

    The spatial mean of the logged solution is interpolated with a Lagrange polynomial through up to ``num_points``
    recorded values around the first one above the threshold. The interpolant is evaluated on ``inter_points``
    equidistant times. If the interpolated value at the crossing still deviates from the threshold by more than
    ``temperature_error_thresh``, the computation is repeated with more points, up to 300 evaluation points.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller
        num_points (int): The number of points in the solution you want to use for interpolation
        inter_points (int): The resolution of the interpolation
        temperature_error_thresh (float): The temperature error compared to the actual threshold you want to allow

    Returns:
        float: The time when the temperature threshold is crossed

    Raises:
        IndexError: If the mean temperature never exceeds the threshold
    """
    from pySDC.core.Lagrange import LagrangeApproximation

    # the problem instance lives in different places for the MPI and non-MPI controllers
    P = controller.S.levels[0].prob if controller.useMPI else controller.MS[0].levels[0].prob
    u_thresh = P.u_thresh

    u = get_sorted(stats, type='u', recomputed=False)
    temp = np.array([np.mean(me[1]) for me in u])
    t = np.array([me[0] for me in u])

    # index of the first recorded value above the threshold (IndexError if there is none)
    crossing_index = np.flatnonzero(temp > u_thresh)[0]

    if crossing_index == 0:
        # there is no recorded point before the crossing to interpolate with, so the stencil below would be empty;
        # the first recorded time is the best available estimate
        return t[0]

    # interpolation stuff: center the stencil on the crossing, truncated so it stays within the recorded data
    num_points = min([num_points, crossing_index * 2, len(temp) - crossing_index])
    idx = np.arange(num_points) - num_points // 2 + crossing_index
    t_grid = t[idx]
    u_grid = temp[idx]
    t_inter = np.linspace(t_grid[0], t_grid[-1], inter_points)
    interpolator = LagrangeApproximation(points=t_grid)
    u_inter = interpolator.getInterpolationMatrix(t_inter) @ u_grid

    crossing_inter = np.flatnonzero(u_inter > u_thresh)[0]

    temperature_error = abs(u_inter[crossing_inter] - u_thresh)

    # NOTE(review): this compares the interpolated error against the temperature value itself rather than against
    # the discrete error abs(temp[crossing_index] - u_thresh), so it practically never fires. A strict comparison
    # with the discrete error can legitimately fail when the true crossing lies just before a recorded point,
    # hence it is left unchanged here -- confirm the intended check.
    assert temperature_error < temp[crossing_index], "Temperature error is rising due to interpolation!"

    if temperature_error > temperature_error_thresh and inter_points < 300:
        return get_crossing_time(stats, controller, num_points + 4, inter_points + 15, temperature_error_thresh)

    return t_inter[crossing_inter]

282 

283 

def plot_solution(stats, controller):  # pragma: no cover
    """
    Plot the spatially averaged temperature and the step size over time. Also mark the threshold and maximum
    temperature and any recorded bit flips.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller used for the run

    Returns:
        None
    """
    import matplotlib.pyplot as plt

    _fig, u_ax = plt.subplots(1, 1)
    dt_ax = u_ax.twinx()

    solution = get_sorted(stats, type='u', recomputed=False)
    times = [entry[0] for entry in solution]
    mean_temperature = [np.mean(entry[1]) for entry in solution]
    u_ax.plot(times, mean_temperature, label=r'$T$')

    step_sizes = get_sorted(stats, type='dt', recomputed=False)
    dt_ax.plot([entry[0] for entry in step_sizes], [entry[1] for entry in step_sizes], color='black', ls='--')
    # invisible dummy line so that the step size shows up in the legend of the temperature axis
    u_ax.plot([None], [None], color='black', ls='--', label=r'$\Delta t$')

    # the problem instance lives in different places for the MPI and non-MPI controllers
    prob = controller.S.levels[0].prob if controller.useMPI else controller.MS[0].levels[0].prob
    u_ax.axhline(prob.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
    u_ax.axhline(prob.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for flip in get_sorted(stats, type='bitflip'):
        u_ax.axvline(flip[0], color='grey', label=f'fault at t={flip[0]:.2f}')

    u_ax.legend()
    u_ax.set_xlabel(r'$t$')
    u_ax.set_ylabel(r'$T$')
    dt_ax.set_ylabel(r'$\Delta t$')

311 

312 

def compare_imex_full(plotting=False, leak_type='linear'):
    """
    Compare the results of IMEX and fully implicit runs.

    Both runs use adaptivity. The function asserts that:
    - the final solutions of the two runs agree;
    - the runaway regime is reached;
    - both runs use the same number of right hand side evaluations per step;
    - the global errors stay below fixed bounds.
    It also checks that the IMEX run performs no Newton iterations and the fully implicit run respects the Newton
    iteration limit.

    Args:
        plotting (bool): Plot the solution or not
        leak_type (str): Passed on to the problem class as ``leak_type``

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
    from pySDC.implementations.hooks.log_work import LogWork
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    maxiter = 5
    num_nodes = 3
    newton_maxiter = 99

    custom_description = {
        'problem_params': {
            'newton_tol': 1e-10,
            'newton_maxiter': newton_maxiter,
            'nvars': 2**7,
            'leak_type': leak_type,
        },
        'step_params': {'maxiter': maxiter},
        'sweeper_params': {'num_nodes': num_nodes},
        'convergence_controllers': {
            Adaptivity: {'e_tol': 1e-7},
        },
    }
    custom_controller_params = {'logger_level': 15}

    # results keyed by whether the run was IMEX
    res, rhs, error = {}, {}, {}

    for imex in (False, True):
        stats, controller, _ = run_quench(
            custom_description=custom_description,
            custom_controller_params=custom_controller_params,
            imex=imex,
            Tend=5e2,
            use_MPI=False,
            hook_class=[LogWork, LogGlobalErrorPostRun],
        )

        res[imex] = get_sorted(stats, type='u')[-1][1]
        newton_iter = [entry[1] for entry in get_sorted(stats, type='work_newton')]
        rhs[imex] = np.mean([entry[1] for entry in get_sorted(stats, type='work_rhs')]) // 1
        error[imex] = get_sorted(stats, type='e_global_post_run')[-1][1]

        if not imex:
            assert max(newton_iter) / num_nodes / maxiter <= newton_maxiter, "Took more Newton iterations than allowed!"
        else:
            assert all(count == 0 for count in newton_iter), "IMEX is not supposed to do Newton iterations!"

        if plotting:  # pragma: no cover
            plot_solution(stats, controller)

    diff = abs(res[True] - res[False])
    thresh = 4e-3
    assert (
        diff < thresh
    ), f"Difference between IMEX and fully-implicit too large! Got {diff:.2e}, allowed is only {thresh:.2e}!"

    prob = controller.MS[0].levels[0].prob
    assert (
        max(res[True]) > prob.u_max
    ), f"Expected runaway to happen, but maximum temperature is {max(res[True]):.2e} < u_max={prob.u_max:.2e}!"

    assert (
        rhs[True] == rhs[False]
    ), f"Expected IMEX and fully implicit schemes to take the same number of right hand side evaluations per step, but got {rhs[True]} and {rhs[False]}!"

    assert error[True] < 1.2e-4, f'Expected error of IMEX version to be less than 1.2e-4, but got e={error[True]:.2e}!'
    assert (
        error[False] < 8e-5
    ), f'Expected error of fully implicit version to be less than 8e-5, but got e={error[False]:.2e}!'

386 

387 

def compare_reference_solutions_single():
    """
    Plot the maximum temperature together with the global and local errors of an adaptive, fully implicit run of the
    quench problem. The errors are computed with respect to the reference solution types listed in ``types``.

    Other available reference solution types are ``'DIRK'`` and ``'SDC'``; add them to ``types`` to compare several at
    once (one color per type is available in ``colors``). The figure is saved to ``data/quench_refs_single.pdf``.

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostStep, LogLocalErrorPostStep
    from pySDC.implementations.hooks.log_solution import LogSolution
    from pySDC.projects.Resilience.strategies import DoubleAdaptivityStrategy
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    types = ['scipy']
    colors = ['black', 'teal', 'magenta']
    Tend = 500

    fig, ax = plt.subplots()
    error_ax = ax.twinx()

    strategy = DoubleAdaptivityStrategy()
    controller_params = {'logger_level': 15}

    for ref_type, color in zip(types, colors):
        description = {
            'level_params': {'dt': 5.0, 'restol': -1},
            'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
            'problem_params': {
                'leak_type': 'linear',
                'leak_transition': 'step',
                'nvars': 2**10,
                'reference_sol_type': ref_type,
                'newton_tol': 1e-12,
            },
        }
        description = merge_descriptions(description, strategy.get_custom_description(run_quench, 1))
        # override the strategy's settings with a fixed iteration count and a tighter adaptivity tolerance
        description['step_params'] = {'maxiter': 5}
        description['convergence_controllers'][Adaptivity]['e_tol'] = 1e-7

        stats, controller, _ = run_quench(
            custom_description=description,
            hook_class=[LogGlobalErrorPostStep, LogLocalErrorPostStep, LogSolution],
            Tend=Tend,
            imex=False,
            custom_controller_params=controller_params,
        )
        e_glob = get_sorted(stats, type='e_global_post_step', recomputed=False)
        e_loc = get_sorted(stats, type='e_local_post_step', recomputed=False)
        u = get_sorted(stats, type='u', recomputed=False)

        ax.plot([me[0] for me in u], [max(me[1]) for me in u], color=color, label=f'{ref_type} reference')

        error_ax.plot([me[0] for me in e_glob], [me[1] for me in e_glob], color=color, ls='--')
        error_ax.plot([me[0] for me in e_loc], [me[1] for me in e_loc], color=color, ls=':')

    prob = controller.MS[0].levels[0].prob
    ax.axhline(prob.u_thresh, ls='-.', color='grey')
    ax.axhline(prob.u_max, ls='-.', color='grey')
    # invisible dummy lines that only serve to label the line styles in the legend
    ax.plot([None], [None], ls='--', label=r'$e_\mathrm{global}$', color='grey')
    ax.plot([None], [None], ls=':', label=r'$e_\mathrm{local}$', color='grey')
    error_ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$t$')
    ax.set_ylabel('solution')
    error_ax.set_ylabel('error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig('data/quench_refs_single.pdf', bbox_inches='tight')

453 

454 

def compare_reference_solutions():
    """
    Plot the maximum local error against the step size for fully implicit runs of the quench problem. Errors are
    computed with respect to different reference solution types.

    The figure is saved to ``data/quench_refs.pdf``.

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun, LogLocalErrorPostStep

    types = ['DIRK', 'SDC', 'scipy']
    fig, ax = plt.subplots()
    Tend = 500
    dt_list = [Tend / 2.0**me for me in [2, 3, 4, 5, 6, 7, 8, 9, 10]]

    for ref_type in types:
        # pre-filled so the progress print shows which step sizes are still outstanding
        errors = [None] * len(dt_list)
        for i, dt in enumerate(dt_list):
            description = {
                'level_params': {'dt': dt, 'restol': 1e-10},
                'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
                'problem_params': {
                    'leak_type': 'linear',
                    'leak_transition': 'step',
                    'nvars': 2**10,
                    'reference_sol_type': ref_type,
                },
            }

            stats, controller, _ = run_quench(
                custom_description=description,
                hook_class=[LogGlobalErrorPostRun, LogLocalErrorPostStep],
                Tend=Tend,
                imex=False,
            )
            errors[i] = max([me[1] for me in get_sorted(stats, type='e_local_post_step', recomputed=False)])
            print(errors)
        ax.loglog(dt_list, errors, label=f'{ref_type} reference')

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    # the plotted quantity is the maximum of the local errors, not the global error
    ax.set_ylabel('max. local error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig('data/quench_refs.pdf', bbox_inches='tight')

493 

494 

def check_order(reference_sol_type='scipy'):
    """
    Plot the maximum embedded error estimate against the step size for fully implicit runs of the quench problem
    with different numbers of SDC iterations.

    Dashed lines show reference slopes equal to the number of iterations. The observed orders between consecutive
    step sizes and their mean are printed. The figure is saved to ``data/order_quench_{reference_sol_type}.pdf``.

    Args:
        reference_sol_type (str): Type of reference solution passed to the problem class

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError

    Tend = 500
    maxiter_list = [1, 2, 3, 4, 5]
    dt_list = np.array([Tend / 2.0**me for me in [4, 5, 6, 7, 8]])
    colors = ['black', 'teal', 'magenta', 'orange', 'red']
    controller_params = {'logger_level': 30}

    fig, ax = plt.subplots()

    for maxiter, color in zip(maxiter_list, colors):
        # pre-filled so the progress print shows which step sizes are still outstanding
        errors = [None] * len(dt_list)

        for i, dt in enumerate(dt_list):
            description = {
                'level_params': {'dt': dt},
                'step_params': {'maxiter': maxiter},
                'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
                'problem_params': {
                    'leak_type': 'linear',
                    'leak_transition': 'step',
                    'nvars': 2**7,
                    'reference_sol_type': reference_sol_type,
                    'newton_tol': 1e-10,
                    'lintol': 1e-10,
                    'direct_solver': True,
                },
                'convergence_controllers': {EstimateEmbeddedError: {}},
            }

            stats, controller, _ = run_quench(
                custom_description=description,
                hook_class=[],
                Tend=Tend,
                imex=False,
                custom_controller_params=controller_params,
            )
            errors[i] = max([me[1] for me in get_sorted(stats, type='error_embedded_estimate')])
            print(errors)

        ax.loglog(dt_list, errors, color=color, label=f'{maxiter} iterations')
        # reference line with slope equal to the number of iterations, anchored at the largest step size
        ax.loglog(dt_list, [errors[0] * (me / dt_list[0]) ** maxiter for me in dt_list], color=color, ls='--')

        errors = np.array(errors)
        orders = np.log(errors[1:] / errors[:-1]) / np.log(dt_list[1:] / dt_list[:-1])
        print(orders, np.mean(orders))

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    # the plotted quantity is the maximum embedded error estimate, not the global error
    ax.set_ylabel('max. embedded error estimate')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig(f'data/order_quench_{reference_sol_type}.pdf', bbox_inches='tight')

564 

565 

if __name__ == '__main__':
    # Other experiments in this script: compare_reference_solutions_single(), compare_reference_solutions(),
    # faults(19), compare_imex_full(plotting=True)
    for sol_type in ('scipy',):
        check_order(reference_sol_type=sol_type)
    plt.show()