Coverage for pySDC/projects/Resilience/extrapolation_within_Q.py: 0%
53 statements
coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
1import matplotlib.pyplot as plt
2import numpy as np
4from pySDC.implementations.convergence_controller_classes.estimate_extrapolation_error import (
5 EstimateExtrapolationErrorWithinQ,
6)
7from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep
8from pySDC.helpers.stats_helper import get_sorted
10from pySDC.projects.Resilience.piline import run_piline
11from pySDC.projects.Resilience.advection import run_advection
12from pySDC.projects.Resilience.vdp import run_vdp
def multiple_runs(prob, dts, num_nodes, quad_type='RADAU-RIGHT', QI='LU', useMPI=False):
    """
    Make multiple runs of a specific problem and record vital error information

    Args:
        prob (function): A problem from the resilience project to run
        dts (list): The step sizes to run with
        num_nodes (int): Number of nodes
        quad_type (str): Type of nodes
        QI (str): Value forwarded to the sweeper as `sweeper_params['QI']`
        useMPI (bool): Use the MPI-parallel sweeper `generic_implicit_MPI`, with one rank per collocation node

    Returns:
        dict: Errors for multiple runs, keyed by step size; each entry holds the maximum local error ('e_loc') and the
              maximum extrapolation error estimate ('e_ex')
        int: Order of the collocation problem

        On MPI ranks that do not take part in the sweeper (rank >= num_nodes), `None` is returned instead.

    Raises:
        ValueError: If `dts` is empty
    """
    if len(dts) == 0:
        # without any run there is no controller to read the collocation order from
        raise ValueError('Need at least one step size in `dts`!')

    description = {}
    description['level_params'] = {'restol': 1e-10}
    description['step_params'] = {'maxiter': 99}
    # `QI` was previously accepted but never passed on, so callers' choices were silently ignored
    description['sweeper_params'] = {'num_nodes': num_nodes, 'quad_type': quad_type, 'QI': QI}
    description['convergence_controllers'] = {EstimateExtrapolationErrorWithinQ: {}}

    if useMPI:
        from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI, MPI

        description['sweeper_class'] = generic_implicit_MPI

        # ranks 0 .. num_nodes-1 form the sweeper communicator. `Split` is collective on COMM_WORLD, so every rank
        # calls it before the surplus ranks leave. Using the same predicate for the split and the early return
        # fixes the previous off-by-one where rank == num_nodes kept running in the wrong communicator.
        is_active = MPI.COMM_WORLD.rank < num_nodes
        description['sweeper_params']['comm'] = MPI.COMM_WORLD.Split(is_active)
        if not is_active:
            return None

    if prob.__name__ == 'run_advection':
        description['problem_params'] = {'order': 6, 'stencil_type': 'center'}

    res = {}

    for dt in dts:
        description['level_params']['dt'] = dt

        stats, controller, _ = prob(custom_description=description, Tend=5.0 * dt, hook_class=LogLocalErrorPostStep)

        res[dt] = {
            'e_loc': max(me[1] for me in get_sorted(stats, type='e_local_post_step')),
            'e_ex': max(me[1] for me in get_sorted(stats, type='error_extrapolation_estimate')),
        }

    coll_order = controller.MS[0].levels[0].sweep.coll.order
    return res, coll_order
def plot_and_compute_order(ax, res, num_nodes, coll_order, *, error_bounds=(1e-10, 1e-3), atol=0.5):
    """
    Plot and compute the order from the multiple runs ran with `multiple_runs`. Also, it is tested if the expected order
    is reached for the respective errors.

    Args:
        ax (Matplotlib.pyplot.axes): Somewhere to plot, or None to skip plotting
        res (dict): Result from `multiple_runs`
        num_nodes (int): Number of nodes
        coll_order (int): Order of the collocation problem
        error_bounds (tuple): Only errors strictly between these (lower, upper) values are used to compute the order
        atol (float): Absolute tolerance when comparing the computed order with the expected one

    Returns:
        None

    Raises:
        ValueError: If `res` is empty
        AssertionError: If fewer than two errors fall within `error_bounds`, or the order is not as expected
    """
    if not res:
        raise ValueError('Need results from at least one run to compute an order!')

    dts = np.array(list(res.keys()))
    keys = list(res[dts[0]].keys())

    # local error is one order higher than global error
    expected_order = {
        'e_loc': coll_order + 1,
        'e_ex': num_nodes + 1,
    }

    lower, upper = error_bounds

    for key in keys:
        errors = np.array([res[dt][key] for dt in dts])

        # only errors inside the window enter the order estimate
        mask = np.logical_and(errors < upper, errors > lower)
        num_usable = np.count_nonzero(mask)

        # an order needs at least one pair of errors; otherwise np.mean([]) would yield NaN and a confusing failure
        assert num_usable >= 2, (
            f'Need at least two errors between {lower:.0e} and {upper:.0e} to compute the order of {key}, '
            f'but got {num_usable}!'
        )

        order = np.log(errors[mask][1:] / errors[mask][:-1]) / np.log(dts[mask][1:] / dts[mask][:-1])
        mean_order = np.mean(order)

        if ax is not None:
            ax.loglog(dts, errors, label=f'{key}: order={mean_order:.2f}')

        assert np.isclose(
            mean_order, expected_order[key], atol=atol
        ), f'Expected order {expected_order[key]} for {key}, but got {mean_order:.2e}!'

    if ax is not None:
        ax.legend(frameon=False)
def check_order(ax, prob, dts, num_nodes, quad_type, **kwargs):
    """
    Check the order by calling `multiple_runs` and then `plot_and_compute_order`.

    Args:
        ax (Matplotlib.pyplot.axes): Somewhere to plot
        prob (function): A problem from the resilience project to run
        dts (list): The step sizes to run with
        num_nodes (int): Number of nodes
        quad_type (str): Type of nodes
        **kwargs: Passed on to `multiple_runs`, e.g. `QI` or `useMPI`

    Returns:
        None
    """
    result = multiple_runs(prob, dts, num_nodes, quad_type, **kwargs)

    # `multiple_runs` returns None on MPI ranks that do not take part in the sweeper; these ranks have nothing to
    # check, and unpacking None into two names would raise a TypeError
    if result is None:
        return

    res, coll_order = result
    plot_and_compute_order(ax, res, num_nodes, coll_order)
def main():
    """
    Check the order of the errors for the advection problem with three Radau-right nodes, using the MPI sweeper with
    QI='MIN', and show the resulting plot.
    """
    _, axis = plt.subplots()
    step_sizes = [5e-1, 1e-1, 5e-2, 1e-2]
    check_order(axis, run_advection, step_sizes, 3, 'RADAU-RIGHT', QI='MIN', useMPI=True)
    plt.show()


if __name__ == "__main__":
    main()