Coverage for pySDC/projects/Resilience/collocation_adaptivity.py: 95%
132 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import numpy as np
2import matplotlib.pyplot as plt
3from matplotlib.colors import TABLEAU_COLORS
5from pySDC.helpers.stats_helper import get_sorted
6from pySDC.projects.Resilience.vdp import run_vdp
7from pySDC.projects.Resilience.advection import run_advection
8from pySDC.projects.Resilience.heat import run_heat
9from pySDC.projects.Resilience.hook import LogData
10from pySDC.projects.Resilience.accuracy_check import get_accuracy_order
11from pySDC.implementations.convergence_controller_classes.adaptive_collocation import AdaptiveCollocation
12from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
13 EstimateEmbeddedErrorCollocation,
14)
15from pySDC.core.hooks import Hooks
16from pySDC.implementations.hooks.log_errors import LogLocalErrorPostIter
17from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimatePostIter
# define global parameters for running problems and plotting
# colours used for the different runs / collocation problems in the plots, indexed by their position
CMAP = list(TABLEAU_COLORS.values())

# the default step size is set to `Tend`
Tend = 0.015
base_params = {
    'step_params': {'maxiter': 99},
    'sweeper_params': {
        'QI': 'LU',
        'num_nodes': 4,
    },
    'level_params': {'restol': 1e-8, 'dt': Tend},
}

# Parameters for the modes of adaptive collocation.
# NOTE(review): presumably the i-th entry of each list configures the i-th collocation problem solved during a step
# (compare the expected node distributions in `plot_residual`) -- confirm with `EstimateEmbeddedErrorCollocation`.
coll_params_inexact = {
    'num_nodes': [2, 3, 4],
    'restol': [1e-4, 1e-7, 1e-8],
}
coll_params_refinement = {
    'num_nodes': [1, 2, 3, 4],
}
coll_params_reduce = {
    'num_nodes': [4, 3, 2, 1],
}
coll_params_type = {
    'quad_type': ['GAUSS', 'RADAU-RIGHT', 'LOBATTO'],
}

# convergence controllers for every mode of adaptive collocation, keyed by the labels used in tests and plots
special_params = {
    'inexact': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_inexact}},
    'refinement': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_refinement}},
    'reduce': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_reduce}},
    'standard': {},
    'type': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_type}},
}
# define a few hooks
class LogSweeperParams(Hooks):
    """
    Log the sweeper parameters and the order of the collocation method after every iteration to check if the
    adaptive collocation convergence controller is doing what it's supposed to.
    """

    def post_iteration(self, step, level_number):
        """
        Record the sweeper parameters (type 'sweeper_params') and the order of the collocation method
        (type 'coll_order') of the current level after the iteration.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        super().post_iteration(step, level_number)

        L = step.levels[level_number]

        self.add_to_stats(
            process=step.status.slot,
            time=L.time,
            level=L.level_index,
            iter=step.status.iter,
            sweep=L.status.sweep,
            type='sweeper_params',
            # Store a snapshot instead of the live attribute dict: otherwise every record may alias the same mutable
            # object, and later changes to the parameters would rewrite the logged history that is compared
            # iteration by iteration elsewhere in this module.
            value=L.sweep.params.__dict__.copy(),
        )
        self.add_to_stats(
            process=step.status.slot,
            time=L.time,
            level=L.level_index,
            iter=step.status.iter,
            sweep=L.status.sweep,
            type='coll_order',
            value=L.sweep.coll.order,
        )
# plotting functions
def compare_adaptive_collocation(prob):
    """
    Run a problem with various modes of adaptive collocation.

    Every mode in `special_params` is run once and the residual and number of nodes vs. iteration are plotted
    into a shared figure by `plot_residual`, which also checks the node distribution against reference values.

    Args:
        prob (function): A problem from the resilience project to run

    Returns:
        None
    """
    _, ax = plt.subplots()
    node_ax = ax.twinx()

    for i, (key, convergence_controllers) in enumerate(special_params.items()):
        custom_description = {**base_params, 'convergence_controllers': convergence_controllers}
        custom_controller_parameters = {'logger_level': 30}
        stats, _, _ = prob(
            Tend=Tend,
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams],
        )

        plot_residual(stats, ax, node_ax, label=key, color=CMAP[i])
def plot_residual(stats, ax, node_ax, **kwargs):
    """
    Plot residual and nodes vs. iteration.
    Also a test is performed to see if we can reproduce previously obtained results.

    Args:
        stats (pySDC.stats): The stats object of the run
        ax (Matplotlib.pyplot.axes): Somewhere to plot the residual
        node_ax (Matplotlib.pyplot.axes): Somewhere to plot the number of nodes
        **kwargs: Passed on to the plotting calls. Must contain `label`, which selects the reference node
            distribution ('inexact', 'refinement', 'reduce', 'standard' or 'type').

    Returns:
        None

    Raises:
        AssertionError: If the number of nodes per iteration differs from the reference distribution
    """
    sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
    residual = get_sorted(stats, type='residual_post_iteration', sortby='iter')

    # number of collocation nodes in use at every iteration
    nodes = [me[1]['num_nodes'] for me in sweeper_params]

    # test if the expected outcome was achieved
    label = kwargs['label']
    expect = {
        'inexact': [2, 2, 3, 3, 4],
        'refinement': [1, 2, 2, 2, 2, 3, 3, 3, 4],
        'reduce': [4, 4, 4, 4, 4, 3, 2, 2, 1],
        'standard': [4, 4, 4, 4, 4],
        'type': [4, 4, 4, 4, 4, 4],
    }
    # `np.array_equal` also requires matching lengths: `np.allclose` would broadcast a single recorded iteration
    # against the whole reference (passing spuriously) and raise an opaque broadcasting error for other mismatches.
    assert np.array_equal(
        nodes, expect[label]
    ), f"Unexpected distribution of nodes vs. iteration in {label}! Expected {expect[label]}, got {nodes}"

    ax.plot([me[0] for me in residual], [me[1] for me in residual], **kwargs)
    ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$k$')
    ax.set_ylabel(r'residual')

    node_ax.plot([me[0] for me in sweeper_params], nodes, **kwargs, ls='--')
    node_ax.set_ylabel(r'nodes')
def check_order(prob, coll_name, ax, k_ax):
    """
    Make plot of the order of the collocation problems and check if they are as expected.

    For a range of step sizes, a single step of length `dt` is run with the adaptive collocation mode `coll_name`.
    For every collocation problem solved during the step, the local error at the last iteration spent on that problem
    and the embedded error estimate are recorded. Their order of accuracy w.r.t. `dt` is compared against the order
    of the collocation method + 1.

    Args:
        prob (function): A problem from the resilience project to run
        coll_name (str): The name of the collocation refinement strategy (a key of `special_params`)
        ax (Matplotlib.pyplot.axes): Somewhere to plot the errors
        k_ax (Matplotlib.pyplot.axes): Somewhere to plot the extra iterations

    Returns:
        None

    Raises:
        AssertionError: If an observed order deviates from the expected one by more than 0.3
    """
    dt_range = [2.0 ** (-i) for i in range(2, 11)]

    res = []

    # sweeper parameter that distinguishes the collocation problems in the legend
    label_keys = {
        'type': 'quad_type',
    }
    label_key = label_keys.get(coll_name, 'num_nodes')

    for dt in dt_range:
        new_params = {
            'level_params': {'restol': 1e-9, 'dt': dt},
            'sweeper_params': {'num_nodes': 2, 'QI': 'IE'},
        }
        custom_description = {**base_params, 'convergence_controllers': special_params[coll_name], **new_params}
        custom_controller_parameters = {'logger_level': 30}
        stats, _, _ = prob(
            Tend=dt,
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams, LogLocalErrorPostIter, LogEmbeddedErrorEstimatePostIter],
        )

        # An iteration is the last one spent on a collocation problem if the sweeper parameters change afterwards,
        # or if it is the very last iteration.
        sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
        converged_solution = [
            sweeper_params[j][1] != sweeper_params[j + 1][1] for j in range(len(sweeper_params) - 1)
        ] + [True]
        idx = np.arange(len(converged_solution))[converged_solution]
        labels = [sweeper_params[j][1][label_key] for j in idx]
        e_loc = np.array([me[1] for me in get_sorted(stats, type='e_local_post_iteration', sortby='iter')])[
            converged_solution
        ]

        # There is one embedded error estimate fewer than collocation problems. Pad with `None` so the estimates line
        # up with the problems they are attributed to (NOTE(review): presumably each estimate compares two
        # consecutive collocation problems, hence the mode dependent side of the padding -- confirm).
        e_em_raw = [
            me[1] for me in get_sorted(stats, type='error_embedded_estimate_collocation_post_iteration', sortby='iter')
        ]
        e_em = np.array(e_em_raw + [None] if coll_name == 'refinement' else [None] + e_em_raw)
        coll_order = np.array([me[1] for me in get_sorted(stats, type='coll_order', sortby='iter')])[converged_solution]

        # np.diff(idx): iterations spent between the conclusions of consecutive collocation problems
        res += [(dt, e_loc, np.diff(idx), labels, coll_order, e_em)]

    # assemble sth we can compute the order from
    result = {'dt': [me[0] for me in res]}
    embedded_errors = {'dt': [me[0] for me in res]}
    num_sols = len(res[0][1])
    for i in range(num_sols):
        result[i] = [me[1][i] for me in res]
        embedded_errors[i] = [me[5][i] for me in res]

        label = res[0][3][i]
        # the local error converges with one order higher than the collocation method
        expected_order = res[0][4][i] + 1

        ax.scatter(result['dt'], embedded_errors[i], color=CMAP[i])

        # Compute the orders separately so the legend of the local error line reports the local error order
        # rather than whichever order happened to be computed last.
        orders = {}
        for kind, errors in (('local', result), ('embedded', embedded_errors)):
            if None in errors[i]:
                continue
            orders[kind] = np.mean(get_accuracy_order(errors, key=i, thresh=1e-9))
            assert np.isclose(
                orders[kind], expected_order, atol=0.3
            ), f"Expected order: {expected_order}, got {orders[kind]:.2f}!"
        ax.loglog(result['dt'], result[i], label=f'{label} nodes: order: {orders["local"]:.1f}', color=CMAP[i])

        if i > 0:
            extra_iter = [me[2][i - 1] for me in res]
            k_ax.plot(result['dt'], extra_iter, ls='--', color=CMAP[i])
    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel(r'$e_\mathrm{local}$ (lines), $e_\mathrm{embedded}$ (dots)')
    k_ax.set_ylabel(r'extra iterations')
def order_stuff(prob):
    """
    Check and plot the order of accuracy of the collocation problems for several modes of adaptive collocation.

    One panel per mode is created, each with a twin axis for the extra iterations, and `check_order` is run for it.

    Args:
        prob (function): A problem from the resilience project to run

    Returns:
        None
    """
    fig, axs = plt.subplots(1, 3, figsize=(14, 4), sharex=True, sharey=True)
    flat_axs = axs.flatten()
    modes = ['type', 'refinement', 'reduce']

    k_axs = []
    for mode, ax in zip(modes, flat_axs):
        k_axs += [ax.twinx()]
        check_order(prob, mode, ax, k_axs[-1])
        ax.set_title(mode)

    # the axes are shared, so only the outermost panels keep their labels
    for k_ax in k_axs[:-1]:
        k_ax.set_ylabel('')

    for ax in flat_axs[1:]:
        ax.set_xlabel('')
        ax.set_ylabel('')
    fig.tight_layout()
def adaptivity_collocation(plotting=False):
    """
    Run van der Pol with `AdaptivityCollocation` and check that the embedded error estimate of the collocation
    problems stays below the tolerance without being excessively smaller than it.

    Args:
        plotting (bool): Whether to plot the step sizes

    Returns:
        None

    Raises:
        AssertionError: If the estimate exceeds the tolerance or (for interior steps) drops below a tenth of it
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityCollocation

    e_tol = 1e-7

    description = {
        'convergence_controllers': {
            AdaptivityCollocation: {'adaptive_coll_params': {'num_nodes': [2, 3]}, 'e_tol': e_tol},
        },
        'step_params': {'maxiter': 99},
        'level_params': {'restol': 1e-8},
    }
    controller_params = {'logger_level': 30}

    stats, _, _ = run_vdp(custom_description=description, custom_controller_params=controller_params)

    e_em = [me[1] for me in get_sorted(stats, type='error_embedded_estimate_collocation', recomputed=False)]
    assert max(e_em) <= e_tol, "Exceeded threshold for local tolerance when using collocation based adaptivity"
    # The first and last entries are exempt from the lower bound.
    # NOTE(review): presumably because those step sizes are not chosen by the adaptivity -- confirm.
    assert min(e_em[1:-1]) >= e_tol / 10, "Over resolved problem when using collocation based adaptivity"

    if plotting:
        from pySDC.projects.Resilience.vdp import plot_step_sizes

        _, ax = plt.subplots()

        plot_step_sizes(stats, ax, 'error_embedded_estimate_collocation')
def main(plotting=False):
    """
    Run all checks of this module.

    The checks are: collocation based adaptivity, the order of the collocation problems with advection, and the
    comparison of adaptive collocation modes with van der Pol.

    Args:
        plotting (bool): Whether `adaptivity_collocation` should also plot the step sizes

    Returns:
        None
    """
    adaptivity_collocation(plotting=plotting)
    order_stuff(prob=run_advection)
    compare_adaptive_collocation(prob=run_vdp)
if __name__ == "__main__":
    # run everything including the optional step size plot, then display all figures
    main(plotting=True)
    plt.show()