Coverage for pySDC / projects / Resilience / collocation_adaptivity.py: 95%
132 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-13 09:00 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-13 09:00 +0000
1import numpy as np
2import matplotlib.pyplot as plt
3from matplotlib.colors import TABLEAU_COLORS
5from pySDC.helpers.stats_helper import get_sorted
6from pySDC.projects.Resilience.vdp import run_vdp
7from pySDC.projects.Resilience.advection import run_advection
8from pySDC.projects.Resilience.heat import run_heat
9from pySDC.projects.Resilience.hook import LogData
10from pySDC.projects.Resilience.accuracy_check import get_accuracy_order
11from pySDC.implementations.convergence_controller_classes.adaptive_collocation import AdaptiveCollocation
12from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
13 EstimateEmbeddedErrorCollocation,
14)
15from pySDC.core.hooks import Hooks
16from pySDC.implementations.hooks.log_errors import LogLocalErrorPostIter
17from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimatePostIter
# define global parameters for running problems and plotting
CMAP = list(TABLEAU_COLORS.values())

Tend = 0.015
# parameters shared by all runs of `compare_adaptive_collocation`; dt equals Tend, so presumably only a single step
# is computed (assuming the runs start at t = 0)
base_params = {
    'step_params': {'maxiter': 99},
    'sweeper_params': {
        'QI': 'LU',
        'num_nodes': 4,
    },
    'level_params': {'restol': 1e-8, 'dt': Tend},
}

# parameters for the adaptive collocation modes: every list holds one value per collocation problem
coll_params_inexact = {
    'num_nodes': [2, 3, 4],
    'restol': [1e-4, 1e-7, 1e-8],
}
coll_params_refinement = {
    'num_nodes': [1, 2, 3, 4],
}
coll_params_reduce = {
    'num_nodes': [4, 3, 2, 1],
}
coll_params_type = {
    'quad_type': ['GAUSS', 'RADAU-RIGHT', 'LOBATTO'],
}

# convergence controller setup for every mode; 'standard' runs without any convergence controllers
special_params = {
    'inexact': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_inexact}},
    'refinement': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_refinement}},
    'reduce': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_reduce}},
    'standard': {},
    'type': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_type}},
}
# define a few hooks
class LogSweeperParams(Hooks):
    """
    Log the sweeper parameters after every iteration to check if the adaptive collocation convergence controller is
    doing what it's supposed to.

    Two entries are added to the stats after every iteration:
        - `sweeper_params`: a snapshot (shallow copy) of the sweeper parameters as a dictionary
        - `coll_order`: the order of the collocation method currently used by the sweeper
    """

    def post_iteration(self, step, level_number):
        """
        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        super().post_iteration(step, level_number)

        L = step.levels[level_number]

        # meta data shared by both entries
        meta = {
            'process': step.status.slot,
            'time': L.time,
            'level': L.level_index,
            'iter': step.status.iter,
            'sweep': L.status.sweep,
        }

        # Store a copy instead of the live `__dict__`: each entry must reflect the parameters at the time of logging
        # and must not change retroactively if the sweeper parameters are later modified in place.
        self.add_to_stats(**meta, type='sweeper_params', value=dict(L.sweep.params.__dict__))
        self.add_to_stats(**meta, type='coll_order', value=L.sweep.coll.order)
# plotting functions
def compare_adaptive_collocation(prob):
    """
    Run a problem with various modes of adaptive collocation.

    Every mode in `special_params` is run once. The residual and the number of collocation nodes per iteration are
    plotted into a shared figure by `plot_residual`, which also checks the node distribution against reference values.

    Args:
        prob (function): A problem from the resilience project to run

    Returns:
        None
    """
    _, ax = plt.subplots()
    node_ax = ax.twinx()  # secondary y-axis for the number of collocation nodes

    for i, (key, convergence_controllers) in enumerate(special_params.items()):
        custom_description = {**base_params, 'convergence_controllers': convergence_controllers}
        custom_controller_parameters = {'logger_level': 30}
        stats, _, _ = prob(
            Tend=Tend,
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams],
        )

        # cycle through the colors in case there are more modes than colors
        plot_residual(stats, ax, node_ax, label=key, color=CMAP[i % len(CMAP)])
def plot_residual(stats, ax, node_ax, **kwargs):
    """
    Plot residual and nodes vs. iteration.
    Also a test is performed to see if we can reproduce previously obtained results.

    Args:
        stats (pySDC.stats): The stats object of the run
        ax (Matplotlib.pyplot.axes): Somewhere to plot the residual
        node_ax (Matplotlib.pyplot.axes): Somewhere to plot the number of collocation nodes
        **kwargs: Passed on to `plot`. Must contain `label`, which also selects the reference node distribution.

    Returns:
        None
    """
    sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
    residual = get_sorted(stats, type='residual_post_iteration', sortby='iter')

    # number of collocation nodes in every iteration
    nodes = [me[1]['num_nodes'] for me in sweeper_params]

    # test if the expected outcome was achieved
    _check_node_distribution(nodes, kwargs['label'])

    ax.plot([me[0] for me in residual], [me[1] for me in residual], **kwargs)
    ax.set_yscale('log')
    ax.legend(frameon=False)
    ax.set_xlabel(r'$k$')
    ax.set_ylabel(r'residual')

    node_ax.plot([me[0] for me in sweeper_params], nodes, **kwargs, ls='--')
    node_ax.set_ylabel(r'nodes')


def _check_node_distribution(nodes, label):
    """
    Assert that the number of collocation nodes per iteration matches the reference distribution for mode `label`.

    Args:
        nodes (list): Number of collocation nodes in every iteration
        label (str): Name of the adaptive collocation mode

    Returns:
        None
    """
    expect = {
        'inexact': [2, 2, 3, 3, 4],
        'refinement': [1, 2, 2, 2, 2, 3, 3, 3, 4],
        'reduce': [4, 4, 4, 4, 4, 3, 2, 2, 1],
        'standard': [4, 4, 4, 4, 4],
        'type': [4, 4, 4, 4, 4, 4],
    }
    # `np.array_equal` also requires matching shapes. `np.allclose` broadcasts: it would accept e.g. `[4]` for
    # `[4, 4, 4, 4, 4]`, and it raises a ValueError rather than this message for other length mismatches.
    assert np.array_equal(
        nodes, expect[label]
    ), f"Unexpected distribution of nodes vs. iteration in {label}! Expected {expect[label]}, got {nodes}"
def check_order(prob, coll_name, ax, k_ax):
    """
    Make plot of the order of the collocation problems and check if they are as expected.

    For every step size in `dt_range`, the problem is run up to Tend = dt with the adaptive collocation setup
    `special_params[coll_name]`. The local error (`e_local_post_iteration`) of every concluded collocation problem and
    the embedded error estimates are collected. For each collocation problem, the mean order of accuracy of both is
    asserted to be within 0.3 of the collocation order + 1. Datasets containing `None` padding are skipped.

    Args:
        prob (function): A problem from the resilience project to run
        coll_name (str): The name of the collocation refinement strategy, i.e. a key of `special_params`
        ax (Matplotlib.pyplot.axes): Somewhere to plot the errors
        k_ax (Matplotlib.pyplot.axes): Somewhere to plot the number of extra iterations

    Returns:
        None
    """
    dt_range = [2.0 ** (-i) for i in range(2, 11)]

    res = []

    # sweeper parameter used to label the collocation problems; modes not listed here use 'num_nodes'
    label_keys = {
        'type': 'quad_type',
    }

    for i in range(len(dt_range)):
        new_params = {
            'level_params': {'restol': 1e-9, 'dt': dt_range[i]},
            'sweeper_params': {'num_nodes': 2, 'QI': 'IE'},
        }
        custom_description = {**base_params, 'convergence_controllers': special_params[coll_name], **new_params}
        custom_controller_parameters = {'logger_level': 30}
        # Tend equals dt, i.e. presumably a single step (assuming the run starts at t = 0)
        stats, _, _ = prob(
            Tend=dt_range[i],
            custom_description=custom_description,
            custom_controller_params=custom_controller_parameters,
            hook_class=[LogData, LogSweeperParams, LogLocalErrorPostIter, LogEmbeddedErrorEstimatePostIter],
        )

        sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
        # An iteration concludes a collocation problem if the sweeper parameters change afterwards; the very last
        # iteration always concludes one. The comprehension's `i` is local to the comprehension and does not
        # overwrite the loop index of the enclosing `for` loop.
        converged_solution = [
            sweeper_params[i][1] != sweeper_params[i + 1][1] for i in range(len(sweeper_params) - 1)
        ] + [True]
        # iteration indices at which a collocation problem was concluded
        idx = np.arange(len(converged_solution))[converged_solution]
        labels = [sweeper_params[i][1][label_keys.get(coll_name, 'num_nodes')] for i in idx]
        e_loc = np.array([me[1] for me in get_sorted(stats, type='e_local_post_iteration', sortby='iter')])[
            converged_solution
        ]

        e_em_raw = [
            me[1] for me in get_sorted(stats, type='error_embedded_estimate_collocation_post_iteration', sortby='iter')
        ]
        # Pad the estimates with `None`: at the end for 'refinement', at the front otherwise.
        # NOTE(review): presumably this aligns them with the entries of `e_loc`, one per collocation problem — confirm
        e_em = np.array((e_em_raw + [None] if coll_name == 'refinement' else [None] + e_em_raw))
        coll_order = np.array([me[1] for me in get_sorted(stats, type='coll_order', sortby='iter')])[converged_solution]

        # (dt, local errors, iterations between concluded collocation problems, labels, collocation orders,
        #  embedded error estimates)
        res += [(dt_range[i], e_loc, idx[1:] - idx[:-1], labels, coll_order, e_em)]

    # assemble sth we can compute the order from
    result = {'dt': [me[0] for me in res]}
    embedded_errors = {'dt': [me[0] for me in res]}
    # number of collocation problems, taken from the run with the largest step size
    num_sols = len(res[0][1])
    for i in range(num_sols):
        result[i] = [me[1][i] for me in res]
        embedded_errors[i] = [me[5][i] for me in res]

        label = res[0][3][i]
        # expected order of accuracy: collocation order + 1
        expected_order = res[0][4][i] + 1

        ax.scatter(result['dt'], embedded_errors[i], color=CMAP[i])

        for me in [result, embedded_errors]:
            # skip datasets containing the `None` padding
            if None in me[i]:
                continue
            order = get_accuracy_order(me, key=i, thresh=1e-9)
            assert np.isclose(
                np.mean(order), expected_order, atol=0.3
            ), f"Expected order: {expected_order}, got {np.mean(order):.2f}!"
        # NOTE(review): `order` holds the value from the last dataset checked above. That is the embedded error
        # estimate unless it was skipped, although this line shows the local error — confirm this is intended.
        ax.loglog(result['dt'], result[i], label=f'{label} nodes: order: {np.mean(order):.1f}', color=CMAP[i])

        # the extra iterations spent on every collocation problem after the first one
        if i > 0:
            extra_iter = [me[2][i - 1] for me in res]
            k_ax.plot(result['dt'], extra_iter, ls='--', color=CMAP[i])
    ax.legend(frameon=False)
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel(r'$e_\mathrm{local}$ (lines), $e_\mathrm{embedded}$ (dots)')
    k_ax.set_ylabel(r'extra iterations')
def order_stuff(prob):
    """
    Check and plot the order of accuracy of the collocation problems for several adaptive collocation modes.

    One panel is created for each of the modes 'type', 'refinement' and 'reduce'. Each panel has a twin axis for the
    extra iterations. The computation and the checks are done in `check_order`.

    Args:
        prob (function): A problem from the resilience project to run

    Returns:
        None
    """
    fig, axs = plt.subplots(1, 3, figsize=(14, 4), sharex=True, sharey=True)
    modes = ['type', 'refinement', 'reduce']
    flat_axs = axs.flatten()

    k_axs = []
    for ax, mode in zip(flat_axs, modes):
        k_axs += [ax.twinx()]
        check_order(prob, mode, ax, k_axs[-1])
        ax.set_title(mode)

    # only the last (rightmost) twin axis keeps its label
    for k_ax in k_axs[:-1]:
        k_ax.set_ylabel('')

    # only the first panel keeps its axis labels
    for ax in flat_axs[1:]:
        ax.set_xlabel('')
        ax.set_ylabel('')
    fig.tight_layout()
def adaptivity_collocation(plotting=False):
    """
    Run the van der Pol problem with collocation based adaptivity and check the embedded error estimates.

    The embedded error estimates (`error_embedded_estimate_collocation`, excluding recomputed entries) must not
    exceed the tolerance `e_tol`. Apart from the first and last entry, they must not fall below `e_tol / 10`, since
    that would indicate an over-resolved problem.

    Args:
        plotting (bool): Whether to plot the step sizes

    Returns:
        None
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityCollocation

    e_tol = 1e-7

    # number of collocation nodes, one entry per collocation problem
    adaptive_coll_params = {
        'num_nodes': [2, 3],
    }

    description = {
        'convergence_controllers': {
            AdaptivityCollocation: {'adaptive_coll_params': adaptive_coll_params, 'e_tol': e_tol},
        },
        'step_params': {'maxiter': 99},
        'level_params': {'restol': 1e-8},
    }
    controller_params = {'logger_level': 30}

    stats, _, _ = run_vdp(custom_description=description, custom_controller_params=controller_params)

    e_em = get_sorted(stats, type='error_embedded_estimate_collocation', recomputed=False)
    errors = [me[1] for me in e_em]

    # fail with a clear message instead of a ValueError from `max` on an empty list
    assert len(errors) > 0, "No embedded error estimates were recorded when using collocation based adaptivity"
    assert max(errors) <= e_tol, "Exceeded threshold for local tolerance when using collocation based adaptivity"

    # the first and last entry are not checked for over-resolution; with no interior entries there is nothing to check
    interior = errors[1:-1]
    assert (
        not interior or min(interior) >= e_tol / 10
    ), "Over resolved problem when using collocation based adaptivity"

    if plotting:
        from pySDC.projects.Resilience.vdp import plot_step_sizes

        _, ax = plt.subplots()

        plot_step_sizes(stats, ax, 'error_embedded_estimate_collocation')
def main(plotting=False):
    """
    Run all checks of this file.

    Args:
        plotting (bool): Passed on to `adaptivity_collocation` to toggle plotting of the step sizes

    Returns:
        None
    """
    adaptivity_collocation(plotting)

    # the remaining studies always create figures
    for study, problem in ((order_stuff, run_advection), (compare_adaptive_collocation, run_vdp)):
        study(problem)
if __name__ == "__main__":
    # run everything with plotting enabled and display the resulting figures
    main(plotting=True)
    plt.show()