Coverage for pySDC/projects/Resilience/quench.py: 49%
217 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 09:13 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 09:13 +0000
1# script to run a quench problem
2from pySDC.implementations.problem_classes.Quench import Quench, QuenchIMEX
3from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
4from pySDC.core.hooks import Hooks
5from pySDC.helpers.stats_helper import get_sorted
6from pySDC.projects.Resilience.hook import hook_collection, LogData
7from pySDC.projects.Resilience.strategies import merge_descriptions
8from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient
9import numpy as np
11import matplotlib.pyplot as plt
12from pySDC.core.errors import ConvergenceError
class live_plot(Hooks):  # pragma: no cover
    """
    Hook that draws the current state into a live matplotlib figure after every step.

    The left panel shows the solution at the last collocation node together with the
    threshold ``u_thresh`` of the problem, the right panel shows the non-linear part
    of the right hand side. Restarts (e.g. caused by adaptivity) are not marked in
    the plots, so the temperature profile may appear to jump back after a restart.
    """

    def _plot_state(self, step, level_number):  # pragma: no cover
        """
        Redraw both panels with the state of the requested level.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        level = step.levels[level_number]

        # start from empty axes, otherwise the curves of all steps pile up
        for panel in self.axs:
            panel.cla()

        solution_panel = self.axs[0]
        rhs_panel = self.axs[1]
        last_node = level.u[-1]

        # To show all collocation nodes instead of only the last one, plot
        # `level.u[i]` for every `i in range(len(level.u))` into the left panel.
        solution_panel.plot(level.prob.xv, last_node)
        solution_panel.axhline(level.prob.u_thresh, color='black')
        rhs_panel.plot(level.prob.xv, level.prob.eval_f_non_linear(last_node, level.time))
        solution_panel.set_ylim(0, 0.025)

        self.fig.suptitle(f"t={level.time:.2e}, k={step.status.iter}")

        # give the GUI event loop time to actually draw the figure
        plt.pause(1e-1)

    def pre_run(self, step, level_number):  # pragma: no cover
        """
        Create the figure with two side-by-side panels that is drawn into later.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        figure, panels = plt.subplots(1, 2, figsize=(10, 4))
        self.fig = figure
        self.axs = panels

    def post_step(self, step, level_number):  # pragma: no cover
        """
        Update the live plot once the step is finished.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        self._plot_state(step, level_number)
def run_quench(
    custom_description=None,
    num_procs=1,
    Tend=6e2,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    imex=False,
    u0=None,
    t0=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a toy problem of a superconducting magnet with a temperature leak with default parameters.

    Args:
        custom_description (dict): Overwrite presets. If ``custom_description['problem_params']``
            contains the key ``'imex'``, its value takes precedence over the ``imex`` argument.
            The key is not forwarded to the problem class and the dictionary passed in is left
            unmodified.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook): A hook to store data, or a list of hooks
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        imex (bool): Solve the problem IMEX or fully implicit
        u0 (dtype_u): Initial value. Defaults to ``u_exact(t0)`` of the problem
        t0 (float): Starting time. Defaults to 0
        use_MPI (bool): Whether or not to use MPI
        **kwargs: Only ``comm`` is read, and only if ``use_MPI`` is set: the MPI communicator
            handed to the controller (defaults to ``MPI.COMM_WORLD``)

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    if custom_description is not None:
        custom_problem_params = custom_description.get('problem_params', {})
        if 'imex' in custom_problem_params:
            imex = custom_problem_params['imex']
            # 'imex' only selects the problem / sweeper class here and must not be merged into
            # the problem parameters. Strip it from a copy rather than popping it from the
            # caller's dictionary, so that re-using the same description keeps working.
            custom_description = {
                **custom_description,
                'problem_params': {key: value for key, value in custom_problem_params.items() if key != 'imex'},
            }

    level_params = {}
    level_params['dt'] = 10.0

    sweeper_params = {}
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['QE'] = 'PIC'

    problem_params = {
        'newton_tol': 1e-9,
        'direct_solver': False,
        'order': 6,
        'nvars': 2**7,
    }

    step_params = {}
    step_params['maxiter'] = 5

    controller_params = {}
    controller_params['logger_level'] = 30
    # a single hook as well as a list of hooks is accepted
    controller_params['hook_class'] = hook_collection + (hook_class if isinstance(hook_class, list) else [hook_class])
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    description = {}
    description['problem_class'] = QuenchIMEX if imex else Quench
    description['problem_params'] = problem_params
    description['sweeper_class'] = imex_1st_order_efficient if imex else generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    t0 = 0.0 if t0 is None else t0

    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }

    if use_MPI:
        # import lazily so that mpi4py is only required for MPI runs
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0) if u0 is None else u0

    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    crash = False
    try:
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except ConvergenceError as e:
        # report the crash to the caller and hand back whatever was recorded so far
        print(f'Warning: Premature termination! {e}')
        stats = controller.return_stats()
        crash = True
    return stats, controller, crash
def faults(seed=0):  # pragma: no cover
    """
    Plot the mean temperature of three runs into one figure and show it.

    The runs are a reference run without faults (label 'ref'), a run with faults using the
    default description of `run_quench` (label 'fixed') and a run with faults and adaptivity
    (label 'adaptivity'). Both faulty runs share the same random generator.

    Args:
        seed (int): Seed for the random generator handed to the fault injection

    Returns:
        None
    """
    import matplotlib.pyplot as plt
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    _, axis = plt.subplots(1, 1)

    fault_stuff = {'rng': np.random.RandomState(seed), 'args': {}, 'rnd_args': {}}
    controller_params = {'logger_level': 30}
    description = {'level_params': {'dt': 1e1}, 'step_params': {'maxiter': 5}}

    # reference: no faults
    ref_stats, ref_controller, _ = run_quench(
        custom_controller_params=controller_params,
        custom_description=description,
    )
    plot_solution_faults(ref_stats, ref_controller, axis, plot_lines=True, label='ref')

    # faults with the default description of `run_quench`
    fixed_stats, fixed_controller, _ = run_quench(
        fault_stuff=fault_stuff,
        custom_controller_params=controller_params,
    )
    plot_solution_faults(fixed_stats, fixed_controller, axis, label='fixed')

    # faults with adaptivity added to the description
    description['convergence_controllers'] = {Adaptivity: {'e_tol': 1e-7, 'dt_max': 1e2, 'dt_min': 1e-3}}
    adaptive_stats, adaptive_controller, _ = run_quench(
        fault_stuff=fault_stuff,
        custom_controller_params=controller_params,
        custom_description=description,
    )
    plot_solution_faults(adaptive_stats, adaptive_controller, axis, label='adaptivity', ls='--')

    plt.show()
def plot_solution_faults(stats, controller, ax, plot_lines=False, **kwargs):  # pragma: no cover
    """
    Plot the spatial mean of the solution over time and mark the times of recorded bitflips.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller used for the run
        ax (matplotlib.axes.Axes): The axes to plot into
        plot_lines (bool): Also draw horizontal lines at ``u_thresh`` and ``u_max`` of the problem
        **kwargs: Forwarded to the ``plot`` call of the solution (e.g. ``label`` or ``ls``)

    Returns:
        None
    """
    u_ax = ax

    u = get_sorted(stats, type='u', recomputed=False)
    u_ax.plot([me[0] for me in u], [np.mean(me[1]) for me in u], **kwargs)

    if plot_lines:
        # MPI controllers hold a single step `S`, non-MPI ones a list of steps `MS`
        # (same distinction as in `plot_solution`)
        if getattr(controller, 'useMPI', False):
            P = controller.S.levels[0].prob
        else:
            P = controller.MS[0].levels[0].prob
        u_ax.axhline(P.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
        u_ax.axhline(P.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for me in get_sorted(stats, type='bitflip'):
        ax.axvline(me[0], color='grey', label=f'fault at t={me[0]:.2f}')

    u_ax.legend()
    u_ax.set_xlabel(r'$t$')
    u_ax.set_ylabel(r'$T$')
def get_crossing_time(stats, controller, num_points=5, inter_points=50, temperature_error_thresh=1e-5):
    """
    Find the time at which the spatial mean of the temperature first exceeds the threshold.

    The recorded solution around the first value above the threshold is interpolated with a
    Lagrange polynomial on a finer grid. If the interpolated temperature at the detected
    crossing is further away from the threshold than allowed, the function calls itself with
    more points and a finer interpolation grid, as long as `inter_points` is below 300.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller
        num_points (int): The number of points in the solution you want to use for interpolation
        inter_points (int): The resolution of the interpolation
        temperature_error_thresh (float): The temperature error compared to the actual threshold you want to allow

    Returns:
        float: The time when the temperature threshold is crossed
    """
    from qmat.lagrange import LagrangeApproximation

    threshold = controller.MS[0].levels[0].prob.u_thresh

    solution = get_sorted(stats, type='u', recomputed=False)
    times = np.array([me[0] for me in solution])
    mean_temperature = np.array([np.mean(me[1]) for me in solution])

    # index of the first recorded value above the threshold
    first_above = np.flatnonzero(mean_temperature > threshold)[0]

    # restrict the stencil such that it does not reach outside of the recorded data
    stencil_size = min(num_points, first_above * 2, len(mean_temperature) - first_above)
    stencil = first_above + np.arange(stencil_size) - stencil_size // 2
    coarse_times = times[stencil]
    coarse_temperature = mean_temperature[stencil]

    # interpolate to a fine grid spanning the stencil
    fine_times = np.linspace(coarse_times[0], coarse_times[-1], inter_points)
    interpolation_matrix = LagrangeApproximation(points=coarse_times).getInterpolationMatrix(fine_times)
    fine_temperature = interpolation_matrix @ coarse_temperature

    fine_crossing = np.flatnonzero(fine_temperature > threshold)[0]
    temperature_error = abs(fine_temperature[fine_crossing] - threshold)

    assert temperature_error < mean_temperature[first_above], "Temperature error is rising due to interpolation!"

    accurate_enough = temperature_error <= temperature_error_thresh
    if accurate_enough or inter_points >= 300:
        return fine_times[fine_crossing]

    # refine: larger stencil and finer interpolation grid
    return get_crossing_time(stats, controller, stencil_size + 4, inter_points + 15, temperature_error_thresh)
def plot_solution(stats, controller):  # pragma: no cover
    """
    Plot the spatial mean of the solution and the step size over time in a new figure.

    The solution is drawn on the left axis together with horizontal lines at ``u_thresh`` and
    ``u_max`` of the problem; the step size is drawn on a twin axis. Recorded bitflips are
    marked with vertical lines.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller used for the run

    Returns:
        None
    """
    import matplotlib.pyplot as plt

    _, temperature_ax = plt.subplots(1, 1)
    step_size_ax = temperature_ax.twinx()

    solution = get_sorted(stats, type='u', recomputed=False)
    temperature_ax.plot([me[0] for me in solution], [np.mean(me[1]) for me in solution], label=r'$T$')

    step_sizes = get_sorted(stats, type='dt', recomputed=False)
    step_size_ax.plot([me[0] for me in step_sizes], [me[1] for me in step_sizes], color='black', ls='--')
    # empty line on the temperature axis, so the step size shows up in its legend
    temperature_ax.plot([None], [None], color='black', ls='--', label=r'$\Delta t$')

    # MPI controllers hold a single step `S`, non-MPI ones a list of steps `MS`
    problem = controller.S.levels[0].prob if controller.useMPI else controller.MS[0].levels[0].prob
    temperature_ax.axhline(problem.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
    temperature_ax.axhline(problem.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for me in get_sorted(stats, type='bitflip'):
        temperature_ax.axvline(me[0], color='grey', label=f'fault at t={me[0]:.2f}')

    temperature_ax.legend()
    temperature_ax.set_xlabel(r'$t$')
    temperature_ax.set_ylabel(r'$T$')
    step_size_ax.set_ylabel(r'$\Delta t$')
def compare_imex_full(plotting=False, leak_type='linear'):
    """
    Run the problem fully implicit and IMEX with adaptivity and check that the results agree.

    Checked are: the Newton iteration counts (none for IMEX, bounded for fully implicit), the
    difference of the final solutions, that the maximum temperature of the IMEX run exceeds
    ``u_max``, that both schemes need the same mean number of right hand side evaluations, and
    that the global errors after the run are below fixed thresholds.

    Args:
        plotting (bool): Plot the solution or not
        leak_type (str): Passed to the problem as the ``leak_type`` parameter

    Returns:
        None

    Raises:
        AssertionError: If any of the checks fails
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
    from pySDC.implementations.hooks.log_work import LogWork
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    maxiter = 5
    num_nodes = 3
    newton_maxiter = 99

    # results of both runs, keyed by the value of `imex`
    res, rhs, error = {}, {}, {}

    custom_description = {
        'problem_params': {
            'newton_tol': 1e-10,
            'newton_maxiter': newton_maxiter,
            'nvars': 2**7,
            'leak_type': leak_type,
        },
        'step_params': {'maxiter': maxiter},
        'sweeper_params': {'num_nodes': num_nodes},
        'convergence_controllers': {Adaptivity: {'e_tol': 1e-7}},
    }
    custom_controller_params = {'logger_level': 15}

    for imex in (False, True):
        stats, controller, _ = run_quench(
            custom_description=custom_description,
            custom_controller_params=custom_controller_params,
            imex=imex,
            Tend=5e2,
            use_MPI=False,
            hook_class=[LogWork, LogGlobalErrorPostRun],
        )

        res[imex] = get_sorted(stats, type='u')[-1][1]
        rhs[imex] = np.mean([me[1] for me in get_sorted(stats, type='work_rhs')]) // 1
        error[imex] = get_sorted(stats, type='e_global_post_run')[-1][1]
        newton_iter = [me[1] for me in get_sorted(stats, type='work_newton')]

        if not imex:
            assert max(newton_iter) / num_nodes / maxiter <= newton_maxiter, "Took more Newton iterations than allowed!"
        else:
            assert all(me == 0 for me in newton_iter), "IMEX is not supposed to do Newton iterations!"

        if plotting:  # pragma: no cover
            plot_solution(stats, controller)

    diff = abs(res[True] - res[False])
    thresh = 4e-3
    assert (
        diff < thresh
    ), f"Difference between IMEX and fully-implicit too large! Got {diff:.2e}, allowed is only {thresh:.2e}!"

    # `controller` is the one of the last run of the loop, i.e. the IMEX run
    prob = controller.MS[0].levels[0].prob
    assert (
        max(res[True]) > prob.u_max
    ), f"Expected runaway to happen, but maximum temperature is {max(res[True]):.2e} < u_max={prob.u_max:.2e}!"

    assert (
        rhs[True] == rhs[False]
    ), f"Expected IMEX and fully implicit schemes to take the same number of right hand side evaluations per step, but got {rhs[True]} and {rhs[False]}!"

    assert error[True] < 1.2e-4, f'Expected error of IMEX version to be less than 1.2e-4, but got e={error[True]:.2e}!'
    assert (
        error[False] < 8e-5
    ), f'Expected error of fully implicit version to be less than 8e-5, but got e={error[False]:.2e}!'
def compare_reference_solutions_single():
    """
    Plot the solution and the errors of a single adaptive run for each reference solution type.

    For every reference solution type, the fully implicit problem is solved with the
    description of `DoubleAdaptivityStrategy` and the maximum of the solution as well as the
    global and local errors after each step are plotted over time. The figure is stored in
    'data/quench_refs_single.pdf'.

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostStep, LogLocalErrorPostStep
    from pySDC.implementations.hooks.log_solution import LogSolution
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
    from pySDC.projects.Resilience.strategies import merge_descriptions, DoubleAdaptivityStrategy

    # Only the scipy reference is active. 'DIRK' and 'SDC' are the other reference types
    # used in this file and can be added here; one color per type is available.
    types = ['scipy']
    colors = ['black', 'teal', 'magenta']

    fig, ax = plt.subplots()
    error_ax = ax.twinx()
    Tend = 500

    strategy = DoubleAdaptivityStrategy()

    controller_params = {'logger_level': 15}

    for reference_type, color in zip(types, colors):
        description = {
            # negative residual tolerance (the original code overwrote an earlier value of 1e-10 with this)
            'level_params': {'dt': 5.0, 'restol': -1},
            'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
            'problem_params': {
                'leak_type': 'linear',
                'leak_transition': 'step',
                'nvars': 2**10,
                'reference_sol_type': reference_type,
                'newton_tol': 1e-12,
            },
        }
        description = merge_descriptions(description, strategy.get_custom_description(run_quench, 1))
        # set after merging, so these values win over the ones of the strategy
        description['step_params'] = {'maxiter': 5}
        description['convergence_controllers'][Adaptivity]['e_tol'] = 1e-7

        stats, controller, _ = run_quench(
            custom_description=description,
            hook_class=[LogGlobalErrorPostStep, LogLocalErrorPostStep, LogSolution],
            Tend=Tend,
            imex=False,
            custom_controller_params=controller_params,
        )
        e_glob = get_sorted(stats, type='e_global_post_step', recomputed=False)
        e_loc = get_sorted(stats, type='e_local_post_step', recomputed=False)
        u = get_sorted(stats, type='u', recomputed=False)

        ax.plot([me[0] for me in u], [max(me[1]) for me in u], color=color, label=f'{reference_type} reference')

        error_ax.plot([me[0] for me in e_glob], [me[1] for me in e_glob], color=color, ls='--')
        error_ax.plot([me[0] for me in e_loc], [me[1] for me in e_loc], color=color, ls=':')

    # the thresholds are taken from the problem of the last run
    prob = controller.MS[0].levels[0].prob
    ax.axhline(prob.u_thresh, ls='-.', color='grey')
    ax.axhline(prob.u_max, ls='-.', color='grey')
    # empty lines, so the line styles of the errors show up in the legend
    ax.plot([None], [None], ls='--', label=r'$e_\mathrm{global}$', color='grey')
    ax.plot([None], [None], ls=':', label=r'$e_\mathrm{local}$', color='grey')
    error_ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$t$')
    ax.set_ylabel('solution')
    error_ax.set_ylabel('error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig('data/quench_refs_single.pdf', bbox_inches='tight')
def compare_reference_solutions():
    """
    Plot the maximum local error against the step size for each reference solution type.

    For each of the reference types 'DIRK', 'SDC' and 'scipy', the fully implicit problem is
    solved with a range of fixed step sizes and the largest local error after any step is
    recorded. The errors are printed and the figure is stored in 'data/quench_refs.pdf'.

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun, LogLocalErrorPostStep

    types = ['DIRK', 'SDC', 'scipy']
    fig, ax = plt.subplots()
    Tend = 500
    dt_list = [Tend / 2.0**me for me in [2, 3, 4, 5, 6, 7, 8, 9, 10]]

    for reference_type in types:
        errors = []
        for dt in dt_list:
            description = {
                'level_params': {'dt': dt, 'restol': 1e-10},
                'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
                'problem_params': {
                    'leak_type': 'linear',
                    'leak_transition': 'step',
                    'nvars': 2**10,
                    'reference_sol_type': reference_type,
                },
            }

            stats, _, _ = run_quench(
                custom_description=description,
                hook_class=[LogGlobalErrorPostRun, LogLocalErrorPostStep],
                Tend=Tend,
                imex=False,
            )
            # The global error after the run is recorded as well (type 'e_global_post_run'),
            # but the quantity plotted here is the largest local error of the run.
            errors.append(max([me[1] for me in get_sorted(stats, type='e_local_post_step', recomputed=False)]))
        print(errors)
        ax.loglog(dt_list, errors, label=f'{reference_type} reference')

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    # label the axis with what is actually plotted (this used to read 'global error')
    ax.set_ylabel('max. local error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig('data/quench_refs.pdf', bbox_inches='tight')
def check_order(reference_sol_type='scipy'):
    """
    Plot the maximum embedded error estimate against the step size for several iteration counts.

    For 1 to 5 iterations, the fully implicit problem is solved with a range of fixed step
    sizes. The errors are plotted together with a dashed line whose slope in the log-log plot
    equals the number of iterations, and the measured orders are printed. The figure is stored
    in 'data/order_quench_<reference_sol_type>.pdf'.

    Args:
        reference_sol_type (str): Passed to the problem as the ``reference_sol_type`` parameter

    Returns:
        None
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun
    from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError
    from pySDC.implementations.sweeper_classes.Runge_Kutta import DIRK43

    Tend = 500
    maxiter_list = [1, 2, 3, 4, 5]
    dt_list = [Tend / 2.0**me for me in [4, 5, 6, 7, 8]]
    colors = ['black', 'teal', 'magenta', 'orange', 'red']
    controller_params = {'logger_level': 30}

    fig, ax = plt.subplots()

    for color, maxiter in zip(colors, maxiter_list):
        errors = []

        for dt in dt_list:
            description = {
                'level_params': {'dt': dt},
                'step_params': {'maxiter': maxiter},
                'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
                'problem_params': {
                    'leak_type': 'linear',
                    'leak_transition': 'step',
                    'nvars': 2**7,
                    'reference_sol_type': reference_sol_type,
                    'newton_tol': 1e-10,
                    'lintol': 1e-10,
                    'direct_solver': True,
                },
                'convergence_controllers': {EstimateEmbeddedError: {}},
            }
            # Disabled alternative: use the `DIRK43` sweeper with sweeper parameters
            # {'maxiter': 1} for the case of 5 iterations.

            # Disabled alternative: add `LogGlobalErrorPostRun` to the hooks and take the last
            # entry of 'e_global_post_run' as the error.
            stats, _, _ = run_quench(
                custom_description=description,
                hook_class=[],
                Tend=Tend,
                imex=False,
                custom_controller_params=controller_params,
            )
            errors.append(max([me[1] for me in get_sorted(stats, type='error_embedded_estimate')]))

        print(errors)
        ax.loglog(dt_list, errors, color=color, label=f'{maxiter} iterations')

        # line of the expected order through the error of the largest step size
        expected = [errors[0] * (me / dt_list[0]) ** maxiter for me in dt_list]
        ax.loglog(dt_list, expected, color=color, ls='--')

        # measured order between neighbouring step sizes
        dt_list = np.array(dt_list)
        errors = np.array(errors)
        orders = np.log(errors[1:] / errors[:-1]) / np.log(dt_list[1:] / dt_list[:-1])
        print(orders, np.mean(orders))

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel('global error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig(f'data/order_quench_{reference_sol_type}.pdf', bbox_inches='tight')
if __name__ == '__main__':
    # Script entry point: run the order check for the scipy reference and show the figures.
    # Disabled alternatives:
    #   compare_reference_solutions_single()
    #   faults(19)
    #   get_crossing_time()
    #   compare_imex_full(plotting=True)
    #   iteration_counts()
    for reference_sol_type in ('scipy',):
        check_order(reference_sol_type=reference_sol_type)
    plt.show()