Coverage for pySDC / projects / PinTSimE / estimation_check.py: 100%
31 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 07:06 +0000
1import numpy as np
2from pathlib import Path
4from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted
5from pySDC.implementations.problem_classes.Battery import battery_n_capacitors, battery, battery_implicit
6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
9from pySDC.projects.PinTSimE.battery_model import runSimulation, plotStylingStuff
11import pySDC.helpers.plot_helper as plt_helper
13from pySDC.projects.PinTSimE.battery_model import LogEventBattery
14from pySDC.implementations.hooks.log_solution import LogSolution
15from pySDC.implementations.hooks.log_step_size import LogStepSize
16from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
def run_estimation_check():
    r"""
    Generate plots that visualise the results of applying the Switch Estimator and adaptivity to the battery
    models.

    The models ``battery``, ``battery_implicit`` and ``battery_n_capacitors`` are simulated via ``runSimulation``
    with ``use_detection=[True, False]`` and ``use_adaptivity=[True, False]``. The results are then passed to
    ``plotAccuracyCheck`` and ``plotStateFunctionOverTime``. The output directory ``data`` is created first if
    it does not exist yet.

    Note
    ----
    Hardcoded solutions for battery models in `pySDC.projects.PinTSimE.hardcoded_solutions` are only computed for
    ``dt_list=[1e-2, 1e-3]`` and ``M_fix=4``. Hence changing ``dt_list`` and ``M_fix`` to different values could
    raise an ``AssertionError``.
    """
    Path("data").mkdir(parents=True, exist_ok=True)

    M_fix = 4

    # ---- sweeper parameters, shared (same object) by all problem classes ----
    sweeper_params = dict(num_nodes=M_fix, quad_type='LOBATTO', QI='IE')

    # ---- event detection parameters and iteration control, shared by all problem classes ----
    handling_params = dict(
        restol=-1,
        maxiter=8,
        max_restarts=50,
        recomputed=False,
        tol_event=1e-10,
        alpha=0.96,
        exact_event_time_avail=None,
    )

    # ---- battery model parameters; the one-capacitor dict is used by two problem classes ----
    one_capacitor = dict(ncapacitors=1, C=np.array([1.0]), alpha=1.2, V_ref=np.array([1.0]))
    two_capacitors = dict(ncapacitors=2, C=np.array([1.0, 1.0]), alpha=1.2, V_ref=np.array([1.0, 1.0]))

    # each entry: (problem class, sweeper class, problem parameters, simulation interval)
    setups = (
        (battery, imex_1st_order, one_capacitor, (0.0, 0.3)),
        (battery_implicit, generic_implicit, one_capacitor, (0.0, 0.3)),
        (battery_n_capacitors, imex_1st_order, two_capacitors, (0.0, 0.5)),
    )

    hooks = [LogSolution, LogEventBattery, LogEmbeddedErrorEstimate, LogStepSize]
    detection_flags = [True, False]
    adaptivity_flags = [True, False]

    for problem, sweeper, problem_params, t_interval in setups:
        name = problem.__name__
        u_num = runSimulation(
            problem=problem,
            sweeper=sweeper,
            all_params={
                'sweeper_params': sweeper_params,
                'handling_params': handling_params,
                'problem_params': problem_params,
            },
            use_adaptivity=adaptivity_flags,
            use_detection=detection_flags,
            hook_class=hooks,
            interval=t_interval,
            dt_list=[1e-2, 1e-3],
            nnodes=[M_fix],
        )

        plotAccuracyCheck(u_num, name, M_fix)
        plotStateFunctionOverTime(u_num, name, M_fix)
def plotAccuracyCheck(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Routine to check accuracy for different step sizes in case of using adaptivity.

    One figure is created for every step size ``dt`` in ``u_num``. Each figure covers all runs with adaptivity
    (``use_A=True``):

    * the adapted step sizes on the left y-axis (solid lines);
    * the embedded error estimate on a twin right y-axis (dash-dotted lines);
    * for runs with event detection, the entries of ``t_switches`` as vertical dashed lines.

    The figure is saved as ``data/detection_and_adaptivity_<prob_cls_name>_dt=<dt>_M=<M_fix>.png``.

    Parameters
    ----------
    u_num : dict
        Contains all the data. The dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, and for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the corresponding data is stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    colors = plotStylingStuff()

    use_A = True  # only the runs with adaptivity are plotted
    for dt, results_per_M in u_num.items():
        fig, ax = plt_helper.plt.subplots(1, 1, figsize=(7.5, 5), sharex='col', sharey='row')
        e_ax = ax.twinx()  # right y-axis for the embedded error estimate

        for use_SE, results_per_A in results_per_M[M_fix].items():
            result = results_per_A[use_A]
            dt_val = result['dt']
            e_em_val = result['e_em']

            if use_SE:
                for event_number, t_switch in enumerate(result['t_switches'], start=1):
                    ax.axvline(x=t_switch, linestyle='--', color='tomato', label='Event {}'.format(event_number))

            # first column is plotted on the time axis, second column as the value
            ax.plot(dt_val[:, 0], dt_val[:, 1], color=colors[use_SE][use_A], label=r'SE={}, A={}'.format(use_SE, use_A))
            e_ax.plot(e_em_val[:, 0], e_em_val[:, 1], linestyle='dashdot', color=colors[use_SE][use_A])

        # single-point lines without markers draw nothing visible; they only add legend entries for the line styles
        ax.plot(0, 0, color='black', linestyle='solid', label=r'$\Delta t_\mathrm{adapt}$')
        ax.plot(0, 0, color='black', linestyle='dashdot', label=r'$e_{em}$')

        e_ax.set_yscale('log', base=10)
        e_ax.set_ylabel(r'Embedded error estimate $e_{em}$', fontsize=16)
        e_ax.set_ylim(1e-16, 1e-7)
        e_ax.tick_params(labelsize=16)
        e_ax.minorticks_off()

        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.set_ylim(1e-9, 1e0)
        ax.set_yscale('log', base=10)
        ax.set_xlabel(r'Time $t$', fontsize=16)
        ax.set_ylabel(r'Adapted step sizes $\Delta t_\mathrm{adapt}$', fontsize=16)
        ax.grid(visible=True)
        ax.minorticks_off()
        ax.legend(frameon=True, fontsize=12, loc='center left')

        fig.savefig(
            'data/detection_and_adaptivity_{}_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix),
            dpi=300,
            bbox_inches='tight',
        )
        plt_helper.plt.close(fig)
def _state_function_component(h_val, i, n):
    """
    Extract component ``i`` of stored state function values as a flat array of absolute values.

    Parameters
    ----------
    h_val : np.ndarray
        State function values of one run; indexed as ``h_val[:, i]`` when ``n > 1``.
    i : int
        Index of the component to extract (ignored when ``n == 1``).
    n : int
        Number of state function components.

    Returns
    -------
    np.ndarray
        One-dimensional array of length ``h_val.shape[0]`` holding the absolute values of the component.
    """
    component = h_val[:] if n == 1 else h_val[:, i]
    return np.abs(component.reshape((component.shape[0],)))


def _unique_event_times(t_switches, tol=1e-10):
    """
    Drop event times lying within ``tol`` of their predecessor, keeping the first one of each cluster.

    Parameters
    ----------
    t_switches : array_like
        Event times of one run; presumably one-dimensional and sorted -- TODO confirm against ``getDataDict``.
    tol : float, optional
        Consecutive event times whose distance is at most ``tol`` are treated as the same event.

    Returns
    -------
    np.ndarray
        Event times without consecutive duplicates (returned unchanged when empty).
    """
    t_switches = np.asarray(t_switches)
    if t_switches.size == 0:
        # nothing to deduplicate; also avoids a boolean mask of the wrong length
        return t_switches
    keep = np.append([True], np.abs(np.diff(t_switches)) > tol)
    return t_switches[keep]


def _values_around_event(t_list, h_list, t_switch, after):
    """
    Collect state function values at the grid points enclosing an event, for all step sizes.

    For each step size, every interval with ``t_item[m - 1] < t_switch_item < t_item[m]`` contributes
    ``h_item[m]`` (``after=True``) or ``h_item[m - 1]`` (``after=False``).
    """
    offset = 0 if after else -1
    return [
        h_item[m + offset]
        for (t_item, h_item, t_switch_item) in zip(t_list, h_list, t_switch, strict=True)
        for m in range(1, len(t_item))
        if t_item[m - 1] < t_switch_item < t_item[m]
    ]


def plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Routine that plots the state function at the time before the event, exactly at the event, and after the
    event. Note that this routine only makes sense for a state function that remains constant after the event.

    One figure is created for every state function component ``i``. It has three panels:

    * detection only;
    * adaptivity only;
    * adaptivity and detection.

    Each panel shows absolute state function values against the step sizes. The figure is saved as
    ``data/<prob_cls_name>_comparison_event<i + 1>_M=<M_fix>.png``.

    TODO: Function still does not work as expected. Every time the switch estimator is adapted, the tolerances
    no longer fit!

    Parameters
    ----------
    u_num : dict
        Contains all the data. The dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, and for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the corresponding data is stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    title_cases = {
        0: 'Using detection',
        1: 'Using adaptivity',
        2: 'Using adaptivity and detection',
    }
    # a time point counts as "at the event" within this distance -- NOTE(review): origin of value undocumented
    at_event_tol = 2.7961188919789493e-11

    dt_list = list(u_num.keys())
    use_detection = list(u_num[dt_list[0]][M_fix].keys())
    use_adaptivity = list(u_num[dt_list[0]][M_fix][use_detection[0]].keys())
    h0 = u_num[dt_list[0]][M_fix][use_detection[0]][use_adaptivity[0]]['state_function']
    n = h0[0].shape[0]  # number of state function components

    for i in range(n):
        fig, ax = plt_helper.plt.subplots(1, 3, figsize=(12, 4), sharex='col', sharey='row', squeeze=False)

        # ---- reference run without any handling (no detection, no adaptivity) ----
        no_handling = [u_num[dt][M_fix][False][False] for dt in dt_list]
        t_no_handling = [run['t'] for run in no_handling]
        h_no_handling = [_state_function_component(run['state_function'], i, n) for run in no_handling]

        for use_SE in use_detection:
            for use_A in use_adaptivity:
                if not use_A and not use_SE:
                    # the run without handling is only a reference and gets no panel of its own
                    continue

                # panel 0: detection only, panel 1: adaptivity only, panel 2: adaptivity and detection
                ind = 0 if not use_A else (1 if not use_SE else 2)
                panel = ax[0, ind]
                panel.set_title(r'{} for $n={}$'.format(title_cases[ind], i + 1))

                runs = [u_num[dt][M_fix][use_SE][use_A] for dt in dt_list]
                t = [run['t'] for run in runs]
                h = [_state_function_component(run['state_function'], i, n) for run in runs]

                if use_SE:
                    # i-th deduplicated event of every run
                    t_switch = [_unique_event_times(run['t_switches'])[i] for run in runs]
                    at_event = [
                        h_item[m]
                        for (t_item, h_item, t_switch_item) in zip(t, h, t_switch, strict=True)
                        for m in range(len(t_item))
                        if abs(t_item[m] - t_switch_item) <= at_event_tol
                    ]
                    panel.plot(
                        dt_list,
                        at_event,
                        color='limegreen',
                        marker='s',
                        linestyle='solid',
                        alpha=0.4,
                        label='At event',
                    )
                    # values before/after the event are read from the run without handling
                    t_ref, h_ref = t_no_handling, h_no_handling
                else:
                    # no detected events of its own: use those of the run with detection but without adaptivity
                    t_switch = [
                        _unique_event_times(u_num[dt][M_fix][True][False]['t_switches'])[i] for dt in dt_list
                    ]
                    t_ref, h_ref = t, h

                panel.plot(
                    dt_list,
                    _values_around_event(t_ref, h_ref, t_switch, after=False),
                    color='firebrick',
                    marker='o',
                    linestyle='solid',
                    alpha=0.4,
                    label='Before event',
                )
                panel.plot(
                    dt_list,
                    _values_around_event(t_ref, h_ref, t_switch, after=True),
                    color='deepskyblue',
                    marker='*',
                    linestyle='solid',
                    alpha=0.4,
                    label='After event',
                )

                panel.tick_params(axis='both', which='major', labelsize=16)
                panel.set_xticks(dt_list)
                panel.set_xticklabels(dt_list)
                panel.set_ylim(1e-15, 1e1)
                panel.set_yscale('log', base=10)
                panel.set_xlabel(r'Step size $\Delta t$', fontsize=16)
                ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)
                panel.grid(visible=True)
                panel.minorticks_off()
                panel.legend(frameon=True, fontsize=12, loc='lower left')

        fig.savefig(
            'data/{}_comparison_event{}_M={}.png'.format(prob_cls_name, i + 1, M_fix), dpi=300, bbox_inches='tight'
        )
        plt_helper.plt.close(fig)
def plotStateFunctionOverTime(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Routine that plots the absolute value of the state function over time.

    One figure is created for every step size ``dt`` in ``u_num``. The figure has one panel per state function
    component. Each panel contains one curve per combination of event detection and adaptivity: dash-dotted
    with adaptivity, dotted without. The y-axis is logarithmic, so absolute values are plotted; negative values
    would otherwise be dropped silently. The figure is saved as
    ``data/<prob_cls_name>_state_function_over_time_dt=<dt>_M=<M_fix>.png``.

    Parameters
    ----------
    u_num : dict
        Contains all the data. The dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, and for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the corresponding data is stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Indicates the name of the problem class to be considered.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    dt_list = list(u_num.keys())
    use_detection = list(u_num[dt_list[0]][M_fix].keys())
    use_adaptivity = list(u_num[dt_list[0]][M_fix][use_detection[0]].keys())
    h0 = u_num[dt_list[0]][M_fix][use_detection[0]][use_adaptivity[0]]['state_function']
    n = h0[0].shape[0]  # number of state function components

    for dt in dt_list:
        figsize = (7.5, 5) if n == 1 else (12, 5)
        fig, ax = plt_helper.plt.subplots(1, n, figsize=figsize, sharex='col', sharey='row', squeeze=False)

        # ---- draw all curves first ----
        for use_SE in use_detection:
            for use_A in use_adaptivity:
                t = u_num[dt][M_fix][use_SE][use_A]['t']
                h_val = u_num[dt][M_fix][use_SE][use_A]['state_function']

                linestyle = 'dashdot' if use_A else 'dotted'
                for i in range(n):
                    h = h_val[:] if n == 1 else h_val[:, i]
                    ax[0, i].plot(
                        t, np.abs(h), linestyle=linestyle, label='Detection: {}, Adaptivity: {}'.format(use_SE, use_A)
                    )

        # ---- then format every panel once, so the legend contains all curves ----
        ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)
        for i in range(n):
            panel = ax[0, i]
            panel.set_title(r'$n={}$'.format(i + 1))
            panel.tick_params(axis='both', which='major', labelsize=16)
            panel.set_ylim(1e-15, 1e0)
            panel.set_yscale('log', base=10)
            panel.set_xlabel(r'Time $t$', fontsize=16)
            panel.grid(visible=True)
            panel.minorticks_off()
            panel.legend(frameon=True, fontsize=12, loc='lower left')

        fig.savefig(
            'data/{}_state_function_over_time_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix),
            dpi=300,
            bbox_inches='tight',
        )
        plt_helper.plt.close(fig)
if __name__ == '__main__':
    # Running the module as a script creates ./data and writes all estimation-check plots there.
    run_estimation_check()