Coverage for pySDC/projects/Resilience/quench.py: 49%
218 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1# script to run a quench problem
2from pySDC.implementations.problem_classes.Quench import Quench, QuenchIMEX
3from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
4from pySDC.core.Hooks import hooks
5from pySDC.helpers.stats_helper import get_sorted
6from pySDC.projects.Resilience.hook import hook_collection, LogData
7from pySDC.projects.Resilience.strategies import merge_descriptions
8from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient
9import numpy as np
11import matplotlib.pyplot as plt
12from pySDC.core.Errors import ConvergenceError
class live_plot(hooks):  # pragma: no cover
    """
    Hook that redraws the solution and the non-linear part of the right hand side after every step.

    Restarts (e.g. triggered by adaptivity) are not marked in the figure, so the temperature profile may appear to
    jump backwards after a restart.
    """

    def _plot_state(self, step, level_number):  # pragma: no cover
        """
        Draw the solution at the last collocation node and the non-linear part of the right hand side evaluated there.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        lvl = step.levels[level_number]
        prob = lvl.prob
        u_last = lvl.u[-1]
        solution_ax, rhs_ax = self.axs[0], self.axs[1]

        for axis in self.axs:
            axis.cla()

        # left panel: temperature profile and threshold, right panel: non-linear part of the right hand side
        solution_ax.plot(prob.xv, u_last)
        solution_ax.axhline(prob.u_thresh, color='black')
        rhs_ax.plot(prob.xv, prob.eval_f_non_linear(u_last, lvl.time))
        solution_ax.set_ylim(0, 0.025)

        self.fig.suptitle(f"t={lvl.time:.2e}, k={step.status.iter}")
        plt.pause(1e-1)

    def pre_run(self, step, level_number):  # pragma: no cover
        """
        Create the figure with two side-by-side axes that is drawn into after every step.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        self.fig, self.axs = plt.subplots(1, 2, figsize=(10, 4))

    def post_step(self, step, level_number):  # pragma: no cover
        """
        Redraw the figure once the step is done.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): Number of current level

        Returns:
            None
        """
        self._plot_state(step, level_number)
def run_quench(
    custom_description=None,
    num_procs=1,
    Tend=6e2,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    imex=False,
    u0=None,
    t0=None,
    use_MPI=False,
    **kwargs,
):
    """
    Run a toy problem of a superconducting magnet with a temperature leak with default parameters.

    Args:
        custom_description (dict): Overwrite presets. An ``'imex'`` entry in its ``'problem_params'`` takes precedence
            over the ``imex`` argument. It is removed before merging the description, without modifying the
            caller's dictionary.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): A hook, or a list of hooks, to store data
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        imex (bool): Solve the problem IMEX or fully implicit
        u0 (dtype_u): Initial value
        t0 (float): Starting time
        use_MPI (bool): Whether or not to use MPI
        **kwargs: Only ``comm`` is read, as communicator for the MPI controller (default ``MPI.COMM_WORLD``)

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    if custom_description is not None:
        custom_problem_params = custom_description.get('problem_params', {})
        if 'imex' in custom_problem_params:
            imex = custom_problem_params['imex']
            # 'imex' is consumed here to choose the problem and sweeper classes and is removed before merging.
            # Strip it from shallow copies instead of popping it from the caller's dict, so that reusing the same
            # description for another run keeps the setting.
            custom_description = {
                **custom_description,
                'problem_params': {key: val for key, val in custom_problem_params.items() if key != 'imex'},
            }

    # initialize level parameters
    level_params = {}
    level_params['dt'] = 10.0

    # initialize sweeper parameters
    sweeper_params = {}
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['QE'] = 'PIC'

    problem_params = {
        'newton_tol': 1e-9,
        'direct_solver': False,
        'order': 6,
        'nvars': 2**7,
    }

    # initialize step parameters
    step_params = {}
    step_params['maxiter'] = 5

    # initialize controller parameters
    controller_params = {}
    controller_params['logger_level'] = 30
    controller_params['hook_class'] = hook_collection + (hook_class if isinstance(hook_class, list) else [hook_class])
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = {}
    description['problem_class'] = QuenchIMEX if imex else Quench
    description['problem_params'] = problem_params
    description['sweeper_class'] = imex_1st_order_efficient if imex else generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0 if t0 is None else t0

    # instantiate controller
    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }

    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0) if u0 is None else u0

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    # call main function to get things done...
    crash = False
    try:
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except ConvergenceError as e:
        print(f'Warning: Premature termination! {e}')
        stats = controller.return_stats()
        crash = True
    return stats, controller, crash
def faults(seed=0):  # pragma: no cover
    """
    Plot the mean temperature of three runs into one figure.

    The runs are a fault-free reference, a run with faults and the default step size, and a run with faults and
    adaptivity.

    Args:
        seed (int): Seed for the random generator handed to the fault injection

    Returns:
        None
    """
    import matplotlib.pyplot as plt
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    _, axis = plt.subplots(1, 1)

    fault_stuff = {'rng': np.random.RandomState(seed), 'args': {}, 'rnd_args': {}}
    controller_params = {'logger_level': 30}
    description = {'level_params': {'dt': 1e1}, 'step_params': {'maxiter': 5}}

    # reference run without faults
    stats, controller, _ = run_quench(custom_controller_params=controller_params, custom_description=description)
    plot_solution_faults(stats, controller, axis, plot_lines=True, label='ref')

    # faults with the default step size
    stats, controller, _ = run_quench(fault_stuff=fault_stuff, custom_controller_params=controller_params)
    plot_solution_faults(stats, controller, axis, label='fixed')

    # faults with adaptive step size selection
    description['convergence_controllers'] = {Adaptivity: {'e_tol': 1e-7, 'dt_max': 1e2, 'dt_min': 1e-3}}
    stats, controller, _ = run_quench(
        fault_stuff=fault_stuff, custom_controller_params=controller_params, custom_description=description
    )
    plot_solution_faults(stats, controller, axis, label='adaptivity', ls='--')

    plt.show()
def plot_solution_faults(stats, controller, ax, plot_lines=False, **kwargs):  # pragma: no cover
    """
    Plot the spatial mean of the temperature over time and mark the times at which faults were injected.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller used for the run (MPI or non-MPI)
        ax: Matplotlib axis to plot into
        plot_lines (bool): Also draw horizontal lines at the threshold and the maximum temperature of the problem
        **kwargs: Passed on to ``ax.plot`` for the temperature curve

    Returns:
        None
    """
    u = get_sorted(stats, type='u', recomputed=False)
    ax.plot([me[0] for me in u], [np.mean(me[1]) for me in u], **kwargs)

    if plot_lines:
        # the problem instance lives in a different place for the MPI and the non-MPI controller
        P = controller.S.levels[0].prob if controller.useMPI else controller.MS[0].levels[0].prob
        ax.axhline(P.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
        ax.axhline(P.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for me in get_sorted(stats, type='bitflip'):
        ax.axvline(me[0], color='grey', label=f'fault at t={me[0]:.2f}')

    ax.legend()
    ax.set_xlabel(r'$t$')
    ax.set_ylabel(r'$T$')
def get_crossing_time(stats, controller, num_points=5, inter_points=50, temperature_error_thresh=1e-5):
    """
    Compute the time when the temperature threshold is crossed based on interpolation.

    The spatial mean of the temperature is interpolated with a Lagrange polynomial. The nodes are up to
    ``num_points`` logged solutions around the first one above the threshold. If the interpolated temperature at the
    detected crossing is further than ``temperature_error_thresh`` from the threshold, the function calls itself with
    more interpolation nodes and a finer resolution. It stops once ``inter_points`` is at least 300.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller (MPI or non-MPI)
        num_points (int): The number of points in the solution you want to use for interpolation
        inter_points (int): The resolution of the interpolation
        temperature_error_thresh (float): The temperature error compared to the actual threshold you want to allow

    Returns:
        float: The time when the temperature threshold is crossed

    Raises:
        IndexError: If the mean temperature never exceeds the threshold, or already does at the first logged time
    """
    from pySDC.core.Lagrange import LagrangeApproximation

    # the problem instance lives in a different place for the MPI and the non-MPI controller
    P = controller.S.levels[0].prob if controller.useMPI else controller.MS[0].levels[0].prob
    u_thresh = P.u_thresh

    u = get_sorted(stats, type='u', recomputed=False)
    temp = np.array([np.mean(me[1]) for me in u])
    t = np.array([me[0] for me in u])

    above_thresh = np.flatnonzero(temp > u_thresh)
    if above_thresh.size == 0:
        raise IndexError(f'The mean temperature never exceeds the threshold u_thresh={u_thresh:.2e}!')
    crossing_index = above_thresh[0]
    if crossing_index == 0:
        # without points before the crossing the interpolation stencil below would be empty
        raise IndexError('The mean temperature exceeds the threshold already at the first logged time!')

    # interpolation stuff: a stencil centred on the crossing that stays within the available data
    num_points = min([num_points, crossing_index * 2, len(temp) - crossing_index])
    idx = np.arange(num_points) - num_points // 2 + crossing_index
    t_grid = t[idx]
    u_grid = temp[idx]
    t_inter = np.linspace(t_grid[0], t_grid[-1], inter_points)
    interpolator = LagrangeApproximation(points=t_grid)
    u_inter = interpolator.getInterpolationMatrix(t_inter) @ u_grid

    crossing_inter = np.arange(len(u_inter))[u_inter > u_thresh][0]

    temperature_error = abs(u_inter[crossing_inter] - u_thresh)

    # NOTE(review): this compares the interpolation error to the absolute temperature at the crossing point, not to
    # the error of the uninterpolated grid point abs(temp[crossing_index] - u_thresh) that the message suggests.
    # Confirm the intent before tightening the check.
    assert temperature_error < temp[crossing_index], "Temperature error is rising due to interpolation!"

    if temperature_error > temperature_error_thresh and inter_points < 300:
        return get_crossing_time(stats, controller, num_points + 4, inter_points + 15, temperature_error_thresh)

    return t_inter[crossing_inter]
def plot_solution(stats, controller):  # pragma: no cover
    """
    Plot the spatial mean of the temperature and the step size over time.

    The threshold and maximum temperature of the problem are drawn as horizontal lines, and the times of injected
    faults as vertical lines.

    Args:
        stats (dict): The stats from a pySDC run
        controller (pySDC.Controller.controller): The controller used for the run (MPI or non-MPI)

    Returns:
        None
    """
    _, u_ax = plt.subplots(1, 1)
    dt_ax = u_ax.twinx()

    u = get_sorted(stats, type='u', recomputed=False)
    u_ax.plot([me[0] for me in u], [np.mean(me[1]) for me in u], label=r'$T$')

    dt = get_sorted(stats, type='dt', recomputed=False)
    dt_ax.plot([me[0] for me in dt], [me[1] for me in dt], color='black', ls='--')
    # invisible dummy line so that the step size shows up in the legend of the temperature axis
    u_ax.plot([None], [None], color='black', ls='--', label=r'$\Delta t$')

    # the problem instance lives in a different place for the MPI and the non-MPI controller
    P = controller.S.levels[0].prob if controller.useMPI else controller.MS[0].levels[0].prob
    u_ax.axhline(P.u_thresh, color='grey', ls='-.', label=r'$T_\mathrm{thresh}$')
    u_ax.axhline(P.u_max, color='grey', ls=':', label=r'$T_\mathrm{max}$')

    for me in get_sorted(stats, type='bitflip'):
        u_ax.axvline(me[0], color='grey', label=f'fault at t={me[0]:.2f}')

    u_ax.legend()
    u_ax.set_xlabel(r'$t$')
    u_ax.set_ylabel(r'$T$')
    dt_ax.set_ylabel(r'$\Delta t$')
def compare_imex_full(plotting=False, leak_type='linear'):
    """
    Run the quench problem with adaptivity once fully implicit and once IMEX, and check the results against each other.

    The following is asserted:
    - the final solutions agree within a tolerance;
    - the runaway happens, i.e. the maximum temperature exceeds ``u_max``;
    - both versions use the same number of right hand side evaluations per step;
    - only the fully implicit version does Newton iterations, and no more than allowed;
    - the global errors of both versions are small.

    Args:
        plotting (bool): Plot the solution or not
        leak_type (str): Type of the leak, passed on to the problem parameters
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
    from pySDC.implementations.hooks.log_work import LogWork
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun

    maxiter = 5
    num_nodes = 3
    newton_maxiter = 99

    custom_description = {
        'problem_params': {
            'newton_tol': 1e-10,
            'newton_maxiter': newton_maxiter,
            'nvars': 2**7,
            'leak_type': leak_type,
        },
        'step_params': {'maxiter': maxiter},
        'sweeper_params': {'num_nodes': num_nodes},
        'convergence_controllers': {
            Adaptivity: {'e_tol': 1e-7},
        },
    }
    custom_controller_params = {'logger_level': 15}

    res, rhs, error = {}, {}, {}

    for imex in (False, True):
        stats, controller, _ = run_quench(
            custom_description=custom_description,
            custom_controller_params=custom_controller_params,
            imex=imex,
            Tend=5e2,
            use_MPI=False,
            hook_class=[LogWork, LogGlobalErrorPostRun],
        )

        res[imex] = get_sorted(stats, type='u')[-1][1]
        newton_iter = [entry[1] for entry in get_sorted(stats, type='work_newton')]
        rhs[imex] = np.mean([entry[1] for entry in get_sorted(stats, type='work_rhs')]) // 1
        error[imex] = get_sorted(stats, type='e_global_post_run')[-1][1]

        if not imex:
            assert max(newton_iter) / num_nodes / maxiter <= newton_maxiter, "Took more Newton iterations than allowed!"
        else:
            assert all(entry == 0 for entry in newton_iter), "IMEX is not supposed to do Newton iterations!"

        if plotting:  # pragma: no cover
            plot_solution(stats, controller)

    diff = abs(res[True] - res[False])
    thresh = 4e-3
    assert (
        diff < thresh
    ), f"Difference between IMEX and fully-implicit too large! Got {diff:.2e}, allowed is only {thresh:.2e}!"

    prob = controller.MS[0].levels[0].prob
    assert (
        max(res[True]) > prob.u_max
    ), f"Expected runaway to happen, but maximum temperature is {max(res[True]):.2e} < u_max={prob.u_max:.2e}!"

    assert (
        rhs[True] == rhs[False]
    ), f"Expected IMEX and fully implicit schemes to take the same number of right hand side evaluations per step, but got {rhs[True]} and {rhs[False]}!"

    assert error[True] < 1.2e-4, f'Expected error of IMEX version to be less than 1.2e-4, but got e={error[True]:.2e}!'
    assert (
        error[False] < 8e-5
    ), f'Expected error of fully implicit version to be less than 8e-5, but got e={error[False]:.2e}!'
def compare_reference_solutions_single():
    """
    Plot one adaptive, fully implicit run of the quench problem per type of reference solution.

    For each run, the maximum of the solution is plotted together with the global and local errors on a secondary
    logarithmic axis. The figure is saved to ``data/quench_refs_single.pdf``.
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostStep, LogLocalErrorPostStep
    from pySDC.implementations.hooks.log_solution import LogSolution
    from pySDC.projects.Resilience.strategies import DoubleAdaptivityStrategy
    from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

    types = ['scipy']  # other reference solution types used in this script: 'DIRK', 'SDC'
    colors = ['black', 'teal', 'magenta']

    fig, ax = plt.subplots()
    error_ax = ax.twinx()
    Tend = 500

    strategy = DoubleAdaptivityStrategy()
    controller_params = {'logger_level': 15}

    for ref_type, color in zip(types, colors):
        description = {
            'level_params': {'dt': 5.0, 'restol': -1},
            'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
            'problem_params': {
                'leak_type': 'linear',
                'leak_transition': 'step',
                'nvars': 2**10,
                'reference_sol_type': ref_type,
                'newton_tol': 1e-12,
            },
        }
        description = merge_descriptions(description, strategy.get_custom_description(run_quench, 1))
        description['step_params'] = {'maxiter': 5}
        description['convergence_controllers'][Adaptivity]['e_tol'] = 1e-7

        stats, controller, _ = run_quench(
            custom_description=description,
            hook_class=[LogGlobalErrorPostStep, LogLocalErrorPostStep, LogSolution],
            Tend=Tend,
            imex=False,
            custom_controller_params=controller_params,
        )
        e_glob = get_sorted(stats, type='e_global_post_step', recomputed=False)
        e_loc = get_sorted(stats, type='e_local_post_step', recomputed=False)
        u = get_sorted(stats, type='u', recomputed=False)

        ax.plot([me[0] for me in u], [max(me[1]) for me in u], color=color, label=f'{ref_type} reference')

        error_ax.plot([me[0] for me in e_glob], [me[1] for me in e_glob], color=color, ls='--')
        error_ax.plot([me[0] for me in e_loc], [me[1] for me in e_loc], color=color, ls=':')

    prob = controller.MS[0].levels[0].prob
    ax.axhline(prob.u_thresh, ls='-.', color='grey')
    ax.axhline(prob.u_max, ls='-.', color='grey')
    # invisible dummy lines to explain the line styles of the error axis in the legend
    ax.plot([None], [None], ls='--', label=r'$e_\mathrm{global}$', color='grey')
    ax.plot([None], [None], ls=':', label=r'$e_\mathrm{local}$', color='grey')
    error_ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$t$')
    ax.set_ylabel('solution')
    error_ax.set_ylabel('error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig('data/quench_refs_single.pdf', bbox_inches='tight')
def compare_reference_solutions():
    """
    Plot the maximal local error of fully implicit runs of the quench problem against the step size.

    There is one curve per type of reference solution. The figure is saved to ``data/quench_refs.pdf``.
    """
    from pySDC.implementations.hooks.log_errors import LogGlobalErrorPostRun, LogLocalErrorPostStep

    types = ['DIRK', 'SDC', 'scipy']
    fig, ax = plt.subplots()
    Tend = 500
    dt_list = [Tend / 2.0**me for me in [2, 3, 4, 5, 6, 7, 8, 9, 10]]

    for ref_type in types:
        errors = []
        for dt in dt_list:
            description = {
                'level_params': {'dt': dt, 'restol': 1e-10},
                'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
                'problem_params': {
                    'leak_type': 'linear',
                    'leak_transition': 'step',
                    'nvars': 2**10,
                    'reference_sol_type': ref_type,
                },
            }

            stats, controller, _ = run_quench(
                custom_description=description,
                hook_class=[LogGlobalErrorPostRun, LogLocalErrorPostStep],
                Tend=Tend,
                imex=False,
            )
            # the global error is logged as well ('e_global_post_run'), but the largest local error is plotted
            errors.append(max([me[1] for me in get_sorted(stats, type='e_local_post_step', recomputed=False)]))
        print(errors)
        ax.loglog(dt_list, errors, label=f'{ref_type} reference')

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    # the plotted quantity is the maximum of the local errors, not the global error
    ax.set_ylabel('max. local error')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig('data/quench_refs.pdf', bbox_inches='tight')
def check_order(reference_sol_type='scipy'):
    """
    Estimate the order of fully implicit SDC on the quench problem for different numbers of iterations.

    For each number of iterations, the maximal embedded error estimate is plotted against the step size. A dashed
    line with slope equal to the number of iterations is drawn alongside. The observed orders between consecutive
    step sizes and their mean are printed. The figure is saved to ``data/order_quench_{reference_sol_type}.pdf``.

    Args:
        reference_sol_type (str): Type of reference solution, passed on to the problem parameters
    """
    from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError

    Tend = 500
    maxiter_list = [1, 2, 3, 4, 5]
    dt_list = [Tend / 2.0**me for me in [4, 5, 6, 7, 8]]
    # keep dt_list as plain floats for the descriptions and use a separate array for the order computation
    dt_array = np.array(dt_list)

    fig, ax = plt.subplots()

    controller_params = {'logger_level': 30}

    colors = ['black', 'teal', 'magenta', 'orange', 'red']
    for maxiter, color in zip(maxiter_list, colors):
        errors = []

        for dt in dt_list:
            description = {
                'level_params': {'dt': dt},
                'step_params': {'maxiter': maxiter},
                'sweeper_params': {'QI': 'IE', 'num_nodes': 3},
                'problem_params': {
                    'leak_type': 'linear',
                    'leak_transition': 'step',
                    'nvars': 2**7,
                    'reference_sol_type': reference_sol_type,
                    'newton_tol': 1e-10,
                    'lintol': 1e-10,
                    'direct_solver': True,
                },
                'convergence_controllers': {EstimateEmbeddedError: {}},
            }

            stats, controller, _ = run_quench(
                custom_description=description,
                hook_class=[],
                Tend=Tend,
                imex=False,
                custom_controller_params=controller_params,
            )
            errors.append(max([me[1] for me in get_sorted(stats, type='error_embedded_estimate')]))
        print(errors)
        ax.loglog(dt_list, errors, color=color, label=f'{maxiter} iterations')
        # reference line with slope equal to the number of iterations
        ax.loglog(dt_list, [errors[0] * (me / dt_list[0]) ** maxiter for me in dt_list], color=color, ls='--')

        error_array = np.array(errors)
        orders = np.log(error_array[1:] / error_array[:-1]) / np.log(dt_array[1:] / dt_array[:-1])
        print(orders, np.mean(orders))

    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    # the plotted quantity is the maximum of the embedded error estimates, not the global error
    ax.set_ylabel('max. embedded error estimate')
    ax.set_title('Fully implicit quench problem')
    fig.tight_layout()
    fig.savefig(f'data/order_quench_{reference_sol_type}.pdf', bbox_inches='tight')
if __name__ == '__main__':
    # Further experiments defined in this script, enable as needed:
    #   compare_reference_solutions_single()
    #   faults(19)
    #   compare_imex_full(plotting=True)
    for reference_sol_type in ('scipy',):
        check_order(reference_sol_type=reference_sol_type)
    plt.show()