Coverage for pySDC/projects/PinTSimE/estimation_check.py: 100%
31 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1import numpy as np
2from pathlib import Path
4from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted
5from pySDC.implementations.problem_classes.Battery import battery_n_capacitors, battery, battery_implicit
6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
9from pySDC.projects.PinTSimE.battery_model import runSimulation, plotStylingStuff
11import pySDC.helpers.plot_helper as plt_helper
13from pySDC.projects.PinTSimE.battery_model import LogEventBattery
14from pySDC.implementations.hooks.log_solution import LogSolution
15from pySDC.implementations.hooks.log_step_size import LogStepSize
16from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
def run_estimation_check():
    r"""
    Simulates the battery models for all combinations of switch estimation and adaptivity and plots the results.

    For each problem class the data returned by ``runSimulation`` is passed on to ``plotAccuracyCheck`` and
    ``plotStateFunctionOverTime``. The directory ``data`` is created beforehand if it does not exist yet.

    Note
    ----
    Hardcoded solutions for battery models in `pySDC.projects.PinTSimE.hardcoded_solutions` are only computed for
    ``dt_list=[1e-2, 1e-3]`` and ``M_fix=4``. Hence changing ``dt_list`` and ``M_fix`` to different values may
    raise an ``AssertionError``.
    """

    Path("data").mkdir(parents=True, exist_ok=True)

    # ---- sweeper settings, shared by all problem classes ----
    M_fix = 4
    sweeper_params = {'num_nodes': M_fix, 'quad_type': 'LOBATTO', 'QI': 'IE'}

    # ---- settings for event detection and the maximum number of iterations ----
    handling_params = {
        'restol': -1,
        'maxiter': 8,
        'max_restarts': 50,
        'recomputed': False,
        'tol_event': 1e-10,
        'alpha': 0.96,
        'exact_event_time_avail': None,
    }

    problem_classes = [battery, battery_implicit, battery_n_capacitors]
    sweeper_classes = [imex_1st_order, generic_implicit, imex_1st_order]

    # ---- problem parameters of the battery models ----
    params_battery_1capacitor = {
        'ncapacitors': 1,
        'C': np.array([1.0]),
        'alpha': 1.2,
        'V_ref': np.array([1.0]),
    }
    params_battery_2capacitors = {
        'ncapacitors': 2,
        'C': np.array([1.0, 1.0]),
        'alpha': 1.2,
        'V_ref': np.array([1.0, 1.0]),
    }
    problem_params = {
        'battery': params_battery_1capacitor,
        'battery_implicit': params_battery_1capacitor,
        'battery_n_capacitors': params_battery_2capacitors,
    }

    # ---- full parameter set per problem class; sweeper and handling parameters are the same objects for all ----
    all_params = {
        name: {
            'sweeper_params': sweeper_params,
            'handling_params': handling_params,
            'problem_params': params,
        }
        for name, params in problem_params.items()
    }

    # ---- simulation domain for each problem class ----
    interval = {
        'battery': (0.0, 0.3),
        'battery_implicit': (0.0, 0.3),
        'battery_n_capacitors': (0.0, 0.5),
    }

    hook_class = [LogSolution, LogEventBattery, LogEmbeddedErrorEstimate, LogStepSize]
    use_detection = [True, False]
    use_adaptivity = [True, False]

    for problem, sweeper in zip(problem_classes, sweeper_classes):
        # the parameter dictionaries above are keyed by the name of the problem class
        prob_cls_name = problem.__name__

        u_num = runSimulation(
            problem=problem,
            sweeper=sweeper,
            all_params=all_params[prob_cls_name],
            use_adaptivity=use_adaptivity,
            use_detection=use_detection,
            hook_class=hook_class,
            interval=interval[prob_cls_name],
            dt_list=[1e-2, 1e-3],
            nnodes=[M_fix],
        )

        plotAccuracyCheck(u_num, prob_cls_name, M_fix)

        # NOTE: ``plotStateFunctionAroundEvent`` is deliberately not called here (it was commented out);
        # presumably because of the open TODO in its docstring.

        plotStateFunctionOverTime(u_num, prob_cls_name, M_fix)
def plotAccuracyCheck(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Plots the adapted step sizes and the embedded error estimate over time for the runs using adaptivity.

    One figure per step size in ``u_num`` is created and stored as
    ``data/detection_and_adaptivity_<prob_cls_name>_dt=<dt>_M=<M_fix>.png``. For runs with event detection
    the detected events are additionally marked by vertical lines.

    Parameters
    ----------
    u_num : dict
        Contains all the data. Dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` appropriate stuff is stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    colors = plotStylingStuff()

    # only the runs with adaptivity are considered in this plot
    use_A = True

    for dt in u_num.keys():
        fig, ax = plt_helper.plt.subplots(1, 1, figsize=(7.5, 5), sharex='col', sharey='row')
        e_ax = ax.twinx()  # second y-axis for the embedded error estimate

        for use_SE, runs in u_num[dt][M_fix].items():
            data = runs[use_A]
            dt_val = data['dt']
            e_em_val = data['e_em']
            line_color = colors[use_SE][use_A]

            if use_SE:
                for event_number, t_event in enumerate(data['t_switches'], start=1):
                    ax.axvline(x=t_event, linestyle='--', color='tomato', label='Event {}'.format(event_number))

            ax.plot(dt_val[:, 0], dt_val[:, 1], color=line_color, label=r'SE={}, A={}'.format(use_SE, use_A))
            e_ax.plot(e_em_val[:, 0], e_em_val[:, 1], linestyle='dashdot', color=line_color)

        # dummy lines that only add the two line styles to the legend
        ax.plot(0, 0, color='black', linestyle='solid', label=r'$\Delta t_\mathrm{adapt}$')
        ax.plot(0, 0, color='black', linestyle='dashdot', label=r'$e_{em}$')

        e_ax.set_yscale('log', base=10)
        e_ax.set_ylabel(r'Embedded error estimate $e_{em}$', fontsize=16)
        e_ax.set_ylim(1e-16, 1e-7)
        e_ax.tick_params(labelsize=16)
        e_ax.minorticks_off()

        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.set_ylim(1e-9, 1e0)
        ax.set_yscale('log', base=10)
        ax.set_xlabel(r'Time $t$', fontsize=16)
        ax.set_ylabel(r'Adapted step sizes $\Delta t_\mathrm{adapt}$', fontsize=16)
        ax.grid(visible=True)
        ax.minorticks_off()
        ax.legend(frameon=True, fontsize=12, loc='center left')

        file_name = 'data/detection_and_adaptivity_{}_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix)
        fig.savefig(file_name, dpi=300, bbox_inches='tight')
        plt_helper.plt.close(fig)
def _dropDuplicateEvents(t_switch_item, tol=1e-10):
    r"""
    Returns the event times with consecutive entries closer than ``tol`` reduced to their first occurrence.
    """

    events = np.asarray(t_switch_item)
    keep = np.ones(events.shape, dtype=bool)
    # the first entry is always kept; every further one only if it differs enough from its predecessor
    keep[1:] = np.abs(np.diff(events)) > tol
    return events[keep]


def _stateFunctionComponent(h_val, i, n):
    r"""
    Returns the ``i``-th component of the state function values as one-dimensional array.
    """

    h = h_val[:] if n == 1 else h_val[:, i]
    return h.reshape((h.shape[0],))


def _stateFunctionAtEvent(t, h, t_switch, tol):
    r"""
    Collects the state function values at all times that match the event time up to ``tol``.
    """

    return [
        h_item[m]
        for (t_item, h_item, t_switch_item) in zip(t, h, t_switch)
        for m in range(len(t_item))
        if abs(t_item[m] - t_switch_item) <= tol
    ]


def _stateFunctionNextToEvent(t, h, t_switch, offset):
    r"""
    Collects the state function values next to the event: for each pair of consecutive times enclosing the
    event, the value at the left time (``offset=-1``) or at the right time (``offset=0``) is taken.
    """

    return [
        h_item[m + offset]
        for (t_item, h_item, t_switch_item) in zip(t, h, t_switch)
        for m in range(1, len(t_item))
        if t_item[m - 1] < t_switch_item < t_item[m]
    ]


def plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Routine that plots the state function at time before the event, exactly at the event, and after the event. Note
    that this routine does make sense only for a state function that remains constant after the event.

    For each component of the state function one figure with three panels (detection only, adaptivity only,
    adaptivity and detection) is stored as ``data/<prob_cls_name>_comparison_event<i>_M=<M_fix>.png``.

    TODO: Function still does not work as expected. Every time when the switch estimator is adapted, the tolerances
    does not suit anymore!

    Parameters
    ----------
    u_num : dict
        Contains all the data. Dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` appropriate stuff is stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.

    Raises
    ------
    ValueError
        If a case without event detection is processed before any case with event detection, since then
        no event times are available to compare against.
    """

    title_cases = {
        0: 'Using detection',
        1: 'Using adaptivity',
        2: 'Using adaptivity and detection',
    }

    # NOTE(review): hardcoded tolerance for matching a time against the event time; the TODO in the docstring
    # says it stops fitting whenever the switch estimator changes.
    tol_at_event = 2.7961188919789493e-11

    dt_list = list(u_num.keys())
    use_detection = u_num[dt_list[0]][M_fix].keys()
    use_adaptivity = u_num[dt_list[0]][M_fix][list(use_detection)[0]].keys()
    h0 = u_num[dt_list[0]][M_fix][list(use_detection)[0]][list(use_adaptivity)[0]]['state_function']
    n = h0[0].shape[0]

    # ---- reference data w/o any handling does neither depend on the event nor on the case considered ----
    h_val_no_handling = [u_num[dt][M_fix][False][False]['state_function'] for dt in dt_list]
    t_no_handling = [u_num[dt][M_fix][False][False]['t'] for dt in dt_list]

    for i in range(n):
        fig, ax = plt_helper.plt.subplots(1, 3, figsize=(12, 4), sharex='col', sharey='row', squeeze=False)

        # ---- state function (w/o handling) has two entries or one; choose correct one with reshaping ----
        h_no_handling = [_stateFunctionComponent(item, i, n) for item in h_val_no_handling]

        # event times of the most recent case with detection; reset so that nothing is carried over from
        # the previous event index
        t_switch = None

        for use_SE in use_detection:
            for use_A in use_adaptivity:
                if not use_A and not use_SE:
                    continue

                if use_SE and use_A:
                    ind = 2
                elif use_SE:
                    ind = 0
                else:
                    ind = 1
                ax[0, ind].set_title(r'{} for $n={}$'.format(title_cases[ind], i + 1))

                # ---- same is done here for state function of other cases ----
                h_val = [u_num[dt][M_fix][use_SE][use_A]['state_function'] for dt in dt_list]
                h = [_stateFunctionComponent(item, i, n) for item in h_val]

                t = [u_num[dt][M_fix][use_SE][use_A]['t'] for dt in dt_list]

                if use_SE:
                    # the filtered event times have to be stored; rebinding the loop variable only would
                    # leave the duplicates in place
                    t_switches = [
                        _dropDuplicateEvents(u_num[dt][M_fix][use_SE][use_A]['t_switches']) for dt in dt_list
                    ]
                    t_switch = [t_event[i] for t_event in t_switches]

                    ax[0, ind].plot(
                        dt_list,
                        _stateFunctionAtEvent(t, h, t_switch, tol_at_event),
                        color='limegreen',
                        marker='s',
                        linestyle='solid',
                        alpha=0.4,
                        label='At event',
                    )

                    # with detection, values before and after the event are taken from the run w/o handling
                    t_ref, h_ref = t_no_handling, h_no_handling
                else:
                    if t_switch is None:
                        raise ValueError(
                            'No event times available: a case with event detection has to be processed first!'
                        )

                    # w/o detection, the event times of the last case with detection are used
                    t_ref, h_ref = t, h

                ax[0, ind].plot(
                    dt_list,
                    _stateFunctionNextToEvent(t_ref, h_ref, t_switch, offset=-1),
                    color='firebrick',
                    marker='o',
                    linestyle='solid',
                    alpha=0.4,
                    label='Before event',
                )

                ax[0, ind].plot(
                    dt_list,
                    _stateFunctionNextToEvent(t_ref, h_ref, t_switch, offset=0),
                    color='deepskyblue',
                    marker='*',
                    linestyle='solid',
                    alpha=0.4,
                    label='After event',
                )

                ax[0, ind].tick_params(axis='both', which='major', labelsize=16)
                ax[0, ind].set_xticks(dt_list)
                ax[0, ind].set_xticklabels(dt_list)
                ax[0, ind].set_ylim(1e-15, 1e1)
                ax[0, ind].set_yscale('log', base=10)
                ax[0, ind].set_xlabel(r'Step size $\Delta t$', fontsize=16)
                ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)
                ax[0, ind].grid(visible=True)
                ax[0, ind].minorticks_off()
                ax[0, ind].legend(frameon=True, fontsize=12, loc='lower left')

        fig.savefig(
            'data/{}_comparison_event{}_M={}.png'.format(prob_cls_name, i + 1, M_fix), dpi=300, bbox_inches='tight'
        )
        plt_helper.plt.close(fig)
def plotStateFunctionOverTime(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Plots the state function over time, with one figure per step size and one panel per component.

    All combinations of event detection and adaptivity are drawn into the same panel. Each figure is
    stored as ``data/<prob_cls_name>_state_function_over_time_dt=<dt>_M=<M_fix>.png``.

    Parameters
    ----------
    u_num : dict
        Contains all the data. Dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` appropriate stuff is stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Indicates the name of the problem class to be considered.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    step_sizes = u_num.keys()

    # ---- the available cases and the number of components are read off the first entry ----
    first_dt = list(step_sizes)[0]
    detection_cases = u_num[first_dt][M_fix].keys()
    first_SE = list(detection_cases)[0]
    adaptivity_cases = u_num[first_dt][M_fix][first_SE].keys()
    first_A = list(adaptivity_cases)[0]
    n = u_num[first_dt][M_fix][first_SE][first_A]['state_function'][0].shape[0]

    # a wider figure is used as soon as there is more than one panel
    if n == 1:
        figsize = (7.5, 5)
    else:
        figsize = (12, 5)

    for dt in step_sizes:
        fig, ax = plt_helper.plt.subplots(1, n, figsize=figsize, sharex='col', sharey='row', squeeze=False)

        for use_SE in detection_cases:
            for use_A in adaptivity_cases:
                run = u_num[dt][M_fix][use_SE][use_A]
                t = run['t']
                h_val = run['state_function']

                if use_A:
                    linestyle = 'dashdot'
                else:
                    linestyle = 'dotted'
                curve_label = 'Detection: {}, Adaptivity: {}'.format(use_SE, use_A)

                for i in range(n):
                    panel = ax[0, i]
                    h = h_val[:] if n == 1 else h_val[:, i]

                    panel.set_title(r'$n={}$'.format(i + 1))
                    panel.plot(t, h, linestyle=linestyle, label=curve_label)

                    panel.tick_params(axis='both', which='major', labelsize=16)
                    panel.set_ylim(1e-15, 1e0)
                    panel.set_yscale('log', base=10)
                    panel.set_xlabel(r'Time $t$', fontsize=16)
                    # the y-axis is shared along the row, hence only the first panel is labelled
                    ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)
                    panel.grid(visible=True)
                    panel.minorticks_off()
                    panel.legend(frameon=True, fontsize=12, loc='lower left')

        file_name = 'data/{}_state_function_over_time_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix)
        fig.savefig(file_name, dpi=300, bbox_inches='tight')
        plt_helper.plt.close(fig)
# Generate all plots only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    run_estimation_check()