Coverage for pySDC/projects/Resilience/accuracy_check.py: 80%
142 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1import matplotlib as mpl
2import matplotlib.pylab as plt
3import numpy as np
5from pySDC.helpers.stats_helper import get_sorted
6from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError
7from pySDC.implementations.convergence_controller_classes.estimate_extrapolation_error import (
8 EstimateExtrapolationErrorNonMPI,
9)
10from pySDC.core.hooks import Hooks
11from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep
12from pySDC.projects.Resilience.strategies import merge_descriptions
14import pySDC.helpers.plot_helper as plt_helper
15from pySDC.projects.Resilience.piline import run_piline
class DoNothing(Hooks):
    """
    Hook class that adds nothing on top of the `Hooks` base class.

    Used as `hook_class` when no additional quantities need to be recorded (see `plot_orders`).
    """
def setup_mpl(font_size=8):
    """
    Reset matplotlib through the pySDC plot helper and apply the style used for the figures of this module.

    Args:
        font_size (int): Currently not used; the label and legend font sizes set below are fixed. The parameter is
            kept so that existing callers continue to work.

    Returns:
        None
    """
    plt_helper.setup_mpl(reset=True)

    # fixed label / legend font sizes and small data margins on both axes
    mpl.rcParams.update(
        {
            "axes.labelsize": 12,
            "legend.fontsize": 13,
            "axes.xmargin": 0.03,
            "axes.ymargin": 0.03,
        }
    )
def _last_logged_value(stats, quantity):
    """
    Return the last non-`None` value logged for `quantity` in `stats`, or `None` if there is none.

    Args:
        stats (dict): The stats object from a pySDC run
        quantity (str): The `type` of the logged quantity

    Returns:
        The value stored at index 1 of the last matching entry returned by `get_sorted`, or `None`
    """
    values = [entry[1] for entry in get_sorted(stats, type=quantity) if entry[1] is not None]
    return values[-1] if values else None


def get_results_from_stats(stats, var, val, hook_class=LogLocalErrorPostStep):
    """
    Extract the error quantities of a single run that are used to compute the order.

    Values are only read from `stats` when `hook_class` is `LogLocalErrorPostStep`. For each of the logged
    quantities 'error_extrapolation_estimate', 'error_embedded_estimate' and 'e_local_post_step', the last
    non-`None` entry is used. Quantities that were not logged at all keep their default of 0.0 instead of raising.

    Args:
        stats (dict): The stats object from a pySDC run
        var (str): The variable to compute the order against
        val (float): The value of var corresponding to this run
        hook_class (pySDC.Hook): A hook such that we know what information is available

    Returns:
        dict: The information needed for the order plot, with keys 'e_embedded', 'e_extrapolated', 'e' and `var`
    """
    results = {
        'e_embedded': 0.0,
        'e_extrapolated': 0.0,
        'e': 0.0,
        var: val,
    }

    if hook_class == LogLocalErrorPostStep:
        e_extrapolated = _last_logged_value(stats, 'error_extrapolation_estimate')
        if e_extrapolated is not None:
            results['e_extrapolated'] = e_extrapolated

        e_local = _last_logged_value(stats, 'e_local_post_step')
        if e_local is not None:
            # bounded below by machine epsilon, presumably to keep the logarithms in the order computation finite
            results['e'] = max([e_local, np.finfo(float).eps])

        e_embedded = _last_logged_value(stats, 'error_embedded_estimate')
        if e_embedded is not None:
            results['e_embedded'] = e_embedded

    return results
def multiple_runs(
    k=5,
    serial=True,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    hook_class=LogLocalErrorPostStep,
    custom_controller_params=None,
    var='dt',
    avoid_restarts=False,
    embedded_error_flavor=None,
):
    """
    A simple test program to compute the order of accuracy.

    Args:
        k (int): Number of SDC sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to. If left at `None`, the local error will be
                            computed since a fixed number of steps will be performed.
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        hook_class (pySDC.Hook): A hook for recording relevant information
        custom_controller_params (dict): Custom parameters to pass to the problem
        var (str): The variable to check the order against
        avoid_restarts (bool): Mode of running adaptivity if applicable
        embedded_error_flavor (str): Flavor for the estimation of embedded error

    Returns:
        dict: Maps each recorded quantity to a list with one entry per value in `dt_list`, in the same order.
              An empty dict is returned if `dt_list` is empty.
    """
    # assemble list of values for var if none were given
    if dt_list is None:
        if Tend_fixed:
            dt_list = 0.1 * 10.0 ** -(np.arange(3) / 2)
        else:
            dt_list = 0.01 * 10.0 ** -(np.arange(20) / 10.0)

    num_procs = 1 if serial else 5

    if not embedded_error_flavor:
        embedded_error_flavor = 'standard' if avoid_restarts else 'linearized'

    res = {}
    for val in dt_list:
        desc = {
            'step_params': {'maxiter': k},
            'convergence_controllers': {
                EstimateEmbeddedError.get_implementation(flavor=embedded_error_flavor, useMPI=False): {},
                EstimateExtrapolationErrorNonMPI: {'no_storage': not serial},
            },
        }

        # setup the variable we check the order against
        if var == 'dt':
            desc['level_params'] = {'dt': val}
        elif var == 'e_tol':
            from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

            desc['convergence_controllers'][Adaptivity] = {
                'e_tol': val,
                'avoid_restarts': avoid_restarts,
                'embedded_error_flavor': embedded_error_flavor,
            }

        if custom_description is not None:
            desc = merge_descriptions(desc, custom_description)
        Tend = Tend_fixed if Tend_fixed else 30 * val
        stats, controller, _ = prob(
            custom_description=desc,
            num_procs=num_procs,
            Tend=Tend,
            hook_class=hook_class,
            custom_controller_params=custom_controller_params,
        )

        # errors of the last step on the finest level of the last process
        level = controller.MS[-1].levels[-1]
        t_step_end = level.time + level.dt
        e_glob = abs(level.prob.u_exact(t=t_step_end) - level.u[-1])
        # exact solution started from the step's initial value `level.u[0]` at `level.time`
        e_local = abs(level.prob.u_exact(t=t_step_end, u_init=level.u[0], t_init=level.time) - level.u[-1])

        res_ = get_results_from_stats(stats, var, val, hook_class)
        res_['e_glob'] = e_glob
        res_['e_local'] = e_local

        # accumulate one list entry per run for every quantity
        for key, value in res_.items():
            res.setdefault(key, []).append(value)
    return res
def plot_order(res, ax, k):
    """
    Plot the local error and its measured order using results from `multiple_runs`.

    Args:
        res (dict): The results from `multiple_runs`; must contain the keys 'dt' and 'e_local'
        ax: Somewhere to plot
        k (int): Number of iterations

    Returns:
        None
    """
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # wrap around the color cycle so any k gets a color; identical to `colors[k - 2]` whenever that index exists
    color = colors[(k - 2) % len(colors)]

    key = 'e_local'
    order = get_accuracy_order(res, key=key, thresh=1e-11)
    label = f'k={k}, p={np.mean(order):.2f}'
    ax.loglog(res['dt'], res[key], color=color, ls='-', label=label)
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel(r'$\epsilon$')
    ax.legend(frameon=False, loc='lower right')
def plot(res, ax, k, var='dt', keys=None):
    """
    Plot various errors against `var` and assert that the error estimates have the expected order.

    Args:
        res (dict): The results from `multiple_runs`, containing the errors and the values of `var`
        ax: Somewhere to plot
        k (int): Number of SDC sweeps
        var (str): The variable to compute the order against
        keys (list): List of keys to plot from the results. Defaults to ['e_embedded', 'e_extrapolated', 'e'].

    Returns:
        None

    Raises:
        AssertionError: If the mean order of 'e_embedded' or 'e_extrapolated' deviates from the expected order
    """
    keys = keys if keys else ['e_embedded', 'e_extrapolated', 'e']
    line_styles = ['-', ':', '-.']
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # wrap around the color cycle so any k gets a color; identical to `colors[k - 2]` whenever that index exists
    color = colors[(k - 2) % len(colors)]

    for i, key in enumerate(keys):
        # quantities that were never recorded keep their default of 0 and cannot be shown on a log plot
        if all(me == 0.0 for me in res[key]):
            continue
        order = get_accuracy_order(res, key=key, var=var)
        mean_order = np.mean(order)
        label = None
        if key == 'e_embedded':
            # one legend entry per k, carrying the measured order of the embedded error estimate
            label = rf'$k={k}$, $p={mean_order:.2f}$'
            expect_order = k if var == 'dt' else 1.0
            assert np.isclose(mean_order, expect_order, atol=4e-1), (
                f'Expected embedded error estimate to have order {expect_order} but got {mean_order:.2f}'
            )
        elif key == 'e_extrapolated':
            expect_order = k + 1 if var == 'dt' else 1 + 1 / k
            assert np.isclose(mean_order, expect_order, rtol=3e-1), (
                f'Expected extrapolation error estimate to have order {expect_order} but got {mean_order:.2f}'
            )
        # line style is chosen by position in `keys`
        ax.loglog(res[var], res[key], color=color, ls=line_styles[i % len(line_styles)], label=label)

    if var == 'dt':
        ax.set_xlabel(r'$\Delta t$')
    elif var == 'e_tol':
        ax.set_xlabel(r'$\epsilon_\mathrm{TOL}$')
    else:
        ax.set_xlabel(var)
    ax.set_ylabel(r'$\epsilon$')
    ax.legend(frameon=False, loc='lower right')
def get_accuracy_order(results, key='e_embedded', thresh=1e-14, var='dt'):
    """
    Routine to compute the order of accuracy in time.

    Runs are sorted by decreasing value of `var`, and an order is computed for each pair of consecutive runs as
    log(this_error / prev_error) / log(this_var / prev_var). Each error stays paired with the value of `var`
    from its own run, so the result does not depend on the order in which the runs were stored.

    Args:
        results (dict): the dictionary containing the errors
        key (str): The key in the dictionary corresponding to a specific error
        thresh (float): A threshold below which values are not entered into the order computation
        var (str): The variable to compute the order against

    Returns:
        list: the orders computed from consecutive pairs where both errors exceed `thresh`
    """
    assert var in results, f'ERROR: expecting the list of {var} in the results dictionary'

    # sort (value of var, error) pairs together by decreasing var; sorting by var only avoids comparing errors
    pairs = sorted(zip(results[var], results[key]), key=lambda pair: pair[0], reverse=True)

    order = []
    for (var_prev, e_prev), (var_this, e_this) in zip(pairs[:-1], pairs[1:]):
        try:
            if e_this > thresh and e_prev > thresh:
                order.append(np.log(e_this / e_prev) / np.log(var_this / var_prev))
        except TypeError:
            # e.g. errors that were not recorded (None); report and skip this pair
            print('Type Warning', results[key])

    return order
def plot_orders(
    ax,
    ks,
    serial,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    custom_controller_params=None,
    embedded_error_flavor=None,
):
    """
    Plot only the local error, once for every number of sweeps in `ks`.

    Args:
        ax: Somewhere to plot
        ks (list): List of sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to. If left at `None`, the local error will be
                            computed since a fixed number of steps will be performed.
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        custom_controller_params (dict): Custom parameters to pass to the problem
        embedded_error_flavor (str): Flavor for the estimation of embedded error

    Returns:
        None
    """
    for k in ks:
        res = multiple_runs(
            k=k,
            serial=serial,
            Tend_fixed=Tend_fixed,
            custom_description=custom_description,
            prob=prob,
            dt_list=dt_list,
            hook_class=DoNothing,
            custom_controller_params=custom_controller_params,
            embedded_error_flavor=embedded_error_flavor,
        )
        plot_order(res, ax, k)
def plot_all_errors(
    ax,
    ks,
    serial,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    custom_controller_params=None,
    var='dt',
    avoid_restarts=False,
    embedded_error_flavor=None,
    keys=None,
):
    """
    Make tests for plotting the error and plot a bunch of error estimates.

    Args:
        ax: Somewhere to plot
        ks (list): List of sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to. If left at `None`, the local error will be
                            computed since a fixed number of steps will be performed.
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        custom_controller_params (dict): Custom parameters to pass to the problem
        var (str): The variable to compute the order against
        avoid_restarts (bool): Mode of running adaptivity if applicable
        embedded_error_flavor (str): Flavor for the estimation of embedded error
        keys (list): List of keys to plot from the results. Defaults to ['e_embedded', 'e_extrapolated', 'e'].

    Returns:
        None
    """
    for k in ks:
        res = multiple_runs(
            k=k,
            serial=serial,
            Tend_fixed=Tend_fixed,
            custom_description=custom_description,
            prob=prob,
            dt_list=dt_list,
            custom_controller_params=custom_controller_params,
            var=var,
            avoid_restarts=avoid_restarts,
            embedded_error_flavor=embedded_error_flavor,
        )

        # visualize results
        plot(res, ax, k, var=var, keys=keys)

    # Proxy artists: one black legend entry per plotted key. `plot` picks line styles by position in `keys`,
    # so the legend is built from the same keys in the same order.
    legend_labels = {
        'e_embedded': r'$\epsilon_\mathrm{embedded}$',
        'e_extrapolated': r'$\epsilon_\mathrm{extrapolated}$',
        'e': r'$e$',
    }
    line_styles = ['-', ':', '-.']
    plotted_keys = keys if keys else ['e_embedded', 'e_extrapolated', 'e']
    for i, key in enumerate(plotted_keys):
        ax.plot([None, None], color='black', label=legend_labels.get(key, key), ls=line_styles[i % len(line_styles)])
    ax.legend(frameon=False, loc='lower right')
def check_order_with_adaptivity():
    """
    Test the order when running adaptivity, using the tolerance in place of the step size as the variable the
    order is measured against.

    The embedded error estimate is expected to scale linearly with the tolerance regardless of the number of sweeps,
    because adaptivity tries to match it to the tolerance. The estimate of the error of the last sweep is expected
    to scale with order 1 + 1/k for k sweeps.

    One figure is saved for regular SDC and one for multi-step SDC.
    """
    setup_mpl()
    ks = [3, 2]
    for serial in [True, False]:
        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))
        plot_all_errors(
            ax,
            ks,
            serial,
            Tend_fixed=5e-1,
            var='e_tol',
            dt_list=[1e-5, 5e-6],
            avoid_restarts=False,
            custom_controller_params={'logger_level': 30},
        )
        out_path = (
            'data/error_estimate_order_adaptivity.png'
            if serial
            else 'data/error_estimate_order_adaptivity_parallel.png'
        )
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
def check_order_against_step_size():
    """
    Check the order of the errors and error estimates versus the step size for 4, 3 and 2 sweeps, saving one
    figure for regular SDC and one for multi-step SDC.
    """
    setup_mpl()
    ks = [4, 3, 2]
    figure_paths = {
        True: 'data/error_estimate_order.png',
        False: 'data/error_estimate_order_parallel.png',
    }
    for serial, path in figure_paths.items():
        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))

        plot_all_errors(ax, ks, serial, Tend_fixed=1.0)

        fig.savefig(path, dpi=300, bbox_inches='tight')
        plt.close(fig)
def main():
    """Run the order checks: first with adaptivity, then against the step size."""
    for check in (check_order_with_adaptivity, check_order_against_step_size):
        check()
if __name__ == "__main__":
    # run all checks when this file is executed as a script
    main()