Coverage for pySDC/projects/PinTSimE/battery_model.py: 100%
126 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import numpy as np
2from pathlib import Path
4from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted
5from pySDC.implementations.problem_classes.Battery import battery, battery_implicit, battery_n_capacitors
6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
8from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
10import pySDC.helpers.plot_helper as plt_helper
12from pySDC.core.hooks import Hooks
13from pySDC.implementations.hooks.log_solution import LogSolution
14from pySDC.implementations.hooks.log_step_size import LogStepSize
15from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
17from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
18from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
19from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI
21from pySDC.projects.PinTSimE.hardcoded_solutions import testSolution
class LogEventBattery(Hooks):
    """
    Hook that records the state function of the battery drain model after every step.

    The recorded value is the difference between all components of the end-point solution except the
    first one (the capacitor voltage(s)) and the reference voltage(s) ``V_ref`` of the problem. It is
    stored in the statistics under the type ``'state_function'``.
    """

    def post_step(self, step, level_number):
        super().post_step(step, level_number)

        level = step.levels[level_number]
        prob = level.prob

        # make sure ``level.uend`` holds the solution at the end of the step before reading it
        level.sweep.compute_end_point()

        state_function = level.uend[1:] - prob.V_ref[:]

        self.add_to_stats(
            process=step.status.slot,
            time=level.time + level.dt,
            level=level.level_index,
            iter=0,
            sweep=level.status.sweep,
            type='state_function',
            value=state_function,
        )
def generateDescription(
    dt,
    problem,
    sweeper,
    num_nodes,
    quad_type,
    QI,
    hook_class,
    use_adaptivity,
    use_switch_estimator,
    problem_params,
    restol,
    maxiter,
    max_restarts=None,
    tol_event=1e-10,
    alpha=1.0,
    e_tol=1e-7,
):
    r"""
    Generate a description for the battery models for a controller run.

    Parameters
    ----------
    dt : float
        Time step for computation.
    problem : pySDC.core.Problem
        Problem class that wants to be simulated.
    sweeper : pySDC.core.Sweeper
        Sweeper class for solving the problem class numerically.
    num_nodes : int
        Number of collocation nodes.
    quad_type : str
        Type of quadrature nodes, e.g. ``'LOBATTO'`` or ``'RADAU-RIGHT'``.
    QI : str
        Type of preconditioner used in SDC, e.g. ``'IE'`` or ``'LU'``.
    hook_class : List of pySDC.core.Hooks
        Logged data for a problem, e.g., hook_class=[LogSolution] logs the solution ``'u'``
        during the simulation.
    use_adaptivity : bool
        Flag if the adaptivity wants to be used or not. If ``True``, ``restol`` has to be ``-1``.
    use_switch_estimator : bool
        Flag if the switch estimator wants to be used or not.
    problem_params : dict
        Dictionary containing the problem parameters.
    restol : float
        Residual tolerance to terminate.
    maxiter : int
        Maximum number of iterations to be done.
    max_restarts : int, optional
        Maximum number of restarts per step. If ``None``, no restarting convergence controller is added.
    tol_event : float, optional
        Tolerance for event detection to terminate.
    alpha : float, optional
        Factor that indicates how the new step size in the Switch Estimator is reduced.
    e_tol : float, optional
        Tolerance passed to the ``Adaptivity`` convergence controller (only used if ``use_adaptivity=True``).

    Returns
    -------
    description : dict
        Contains all information for a controller run.
    controller_params : dict
        Parameters needed for a controller run.
    controller : pySDC.implementations.controller_classes.controller_nonMPI.controller_nonMPI
        Non-MPI controller with one process, built from ``description`` and ``controller_params``.
    """

    if use_adaptivity:
        assert restol == -1, "Please set restol to -1 or omit it"

    # initialize level parameters
    level_params = {
        'restol': -1 if use_adaptivity else restol,
        'dt': dt,
    }

    # initialize sweeper parameters
    sweeper_params = {
        'quad_type': quad_type,
        'num_nodes': num_nodes,
        'QI': QI,
        'initial_guess': 'spread',
    }

    # initialize step parameters; no 'errtol' is set since no exact solution is known to compute an error
    step_params = {
        'maxiter': maxiter,
    }

    # initialize controller parameters
    controller_params = {
        'logger_level': 30,
        'hook_class': hook_class,
        'mssdc_jac': False,
    }

    # convergence controllers
    convergence_controllers = {}
    if use_switch_estimator:
        convergence_controllers[SwitchEstimator] = {
            'tol': tol_event,
            'alpha': alpha,
        }
    if use_adaptivity:
        convergence_controllers[Adaptivity] = {
            'e_tol': e_tol,
        }
    if max_restarts is not None:
        convergence_controllers[BasicRestartingNonMPI] = {
            'max_restarts': max_restarts,
            'crash_after_max_restarts': False,
        }

    # fill description dictionary for easy step instantiation
    description = {
        'problem_class': problem,
        'problem_params': problem_params,
        'sweeper_class': sweeper,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
        'convergence_controllers': convergence_controllers,
    }

    # instantiate controller
    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    return description, controller_params, controller
def controllerRun(description, controller_params, controller, t0, Tend, exact_event_time_avail=False):
    """
    Executes a controller run for a problem defined in the description.

    Parameters
    ----------
    description : dict
        Contains all information for a controller run. Not used here directly; kept for interface
        compatibility since ``controller`` was already built from it.
    controller_params : dict
        Parameters needed for a controller run. Not used here directly (see ``description``).
    controller : pySDC.core.Controller
        Controller to do the stuff.
    t0 : float
        Starting time of simulation.
    Tend : float
        End time of simulation.
    exact_event_time_avail : bool, optional
        Indicates if exact event time of a problem is available.

    Returns
    -------
    stats : dict
        Raw statistics from a controller run.
    t_switch_exact : float or None
        The problem's ``t_switch_exact`` attribute if ``exact_event_time_avail`` is truthy, otherwise ``None``.
    """

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)
    t_switch_exact = P.t_switch_exact if exact_event_time_avail else None

    # call main function to get things done; only the statistics are needed, the final solution is discarded
    _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    return stats, t_switch_exact
def main():
    r"""
    Runs the battery drain model simulations for all configurations covered by the hardcoded solutions.

    Two groups of runs are performed, each with and without event detection and adaptivity:

    - ``battery`` (IMEX sweeper) and ``battery_implicit`` (implicit sweeper) with one capacitor on
      ``[0, 0.3]``, once with explicitly given problem parameters and once relying on the problem class
      defaults,
    - ``battery_n_capacitors`` (IMEX sweeper) with two capacitors on ``[0, 0.5]``.

    Note
    ----
    Hardcoded solutions for battery models in `pySDC.projects.PinTSimE.hardcoded_solutions` are only computed for
    ``dt_list=[1e-2, 1e-3]`` and ``M_fix=4``. Hence changing ``dt_list`` and ``M_fix`` to different values could arise
    an ``AssertionError``.
    """

    M_fix = 4

    # sweeper parameters shared by all runs
    sweeper_params = dict(
        num_nodes=M_fix,
        quad_type='LOBATTO',
        QI='IE',
    )

    # parameters for event detection, restol, and max. number of iterations/restarts
    handling_params = dict(
        restol=-1,
        maxiter=8,
        max_restarts=50,
        recomputed=False,
        tol_event=1e-10,
        alpha=0.96,
        exact_event_time_avail=None,
    )

    all_params = dict(
        sweeper_params=sweeper_params,
        handling_params=handling_params,
    )

    # arguments identical for every call of runSimulation
    shared_kwargs = dict(
        hook_class=[LogSolution, LogEventBattery, LogEmbeddedErrorEstimate, LogStepSize],
        use_adaptivity=[True, False],
        use_detection=[True, False],
    )

    one_capacitor_setups = (
        (battery, imex_1st_order),
        (battery_implicit, generic_implicit),
    )

    for problem, sweeper in one_capacitor_setups:
        for defaults in (False, True):
            # with defaults only the number of capacitors is set; for the hardcoded solutions the problem
            # class defaults have to match the explicitly given parameters
            problem_params = {'ncapacitors': 1}
            if not defaults:
                problem_params.update(C=np.array([1.0]), alpha=1.2, V_ref=np.array([1.0]))

            all_params['problem_params'] = problem_params

            runSimulation(
                problem=problem,
                sweeper=sweeper,
                all_params=all_params,
                interval=(0.0, 0.3),
                dt_list=[1e-2, 1e-3],
                nnodes=[M_fix],
                **shared_kwargs,
            )

    # battery model with two capacitors
    all_params['problem_params'] = {
        'ncapacitors': 2,
        'C': np.array([1.0, 1.0]),
        'alpha': 1.2,
        'V_ref': np.array([1.0, 1.0]),
    }

    runSimulation(
        problem=battery_n_capacitors,
        sweeper=imex_1st_order,
        all_params=all_params,
        interval=(0.0, 0.5),
        dt_list=[1e-2, 1e-3],
        nnodes=[sweeper_params['num_nodes']],
        **shared_kwargs,
    )
def runSimulation(problem, sweeper, all_params, use_adaptivity, use_detection, hook_class, interval, dt_list, nnodes):
    r"""
    Script that executes the simulation for a given problem class for given parameters defined by the user.

    Parameters
    ----------
    problem : pySDC.core.Problem
        Problem class to be simulated.
    sweeper : pySDC.core.Sweeper
        Sweeper that is used to simulate the problem class.
    all_params : dict
        Dictionary contains the problem parameters for ``problem`` (key ``'problem_params'``), the sweeper
        parameters for ``sweeper`` (key ``'sweeper_params'``), and handling parameters (key ``'handling_params'``),
        i.e., ``restol``, ``maxiter``, ``max_restarts``, ``recomputed``, ``tol_event``, ``alpha``, and
        ``exact_event_time_avail``.
    use_adaptivity : list of bool
        Indicates whether adaptivity is used in the simulation or not. Here a list is used to iterate over the
        different cases, i.e., ``use_adaptivity=[True, False]``.
    use_detection : list of bool
        Indicates whether event detection is used in the simulation or not. Here a list is used to iterate over the
        different cases, i.e., ``use_detection=[True, False]``.
    hook_class : list of pySDC.core.Hooks
        List containing the different hook classes to log data during the simulation, i.e., ``hook_class=[LogSolution]``
        logs the solution ``u``.
    interval : tuple
        Simulation interval.
    dt_list : list of float
        List containing different step sizes where the solution is computed.
    nnodes : list of int
        The solution can be computed for different number of collocation nodes. Must contain
        ``all_params['sweeper_params']['num_nodes']``.

    Returns
    -------
    u_num : dict
        Nested dictionary ``u_num[dt][M][use_detection][use_adaptivity]`` holding the data extracted by
        ``getDataDict`` for each run.
    """

    Path("data").mkdir(parents=True, exist_ok=True)

    prob_cls_name = problem.__name__

    # the parameter sets are the same for every run, so read them once
    problem_params = all_params['problem_params']
    sweeper_params = all_params['sweeper_params']
    handling_params = all_params['handling_params']

    # plotting/testing results for fixed M requires that M_fix is included in nnodes!
    M_fix = sweeper_params['num_nodes']
    assert M_fix in nnodes, f"For fixed number of collocation nodes {M_fix} no solution will be computed!"

    u_num = {}

    for dt in dt_list:
        u_num[dt] = {}

        for M in nnodes:
            u_num[dt][M] = {}

            for use_SE in use_detection:
                u_num[dt][M][use_SE] = {}

                for use_A in use_adaptivity:
                    restol = -1 if use_A else handling_params['restol']

                    description, controller_params, controller = generateDescription(
                        dt=dt,
                        problem=problem,
                        sweeper=sweeper,
                        num_nodes=M,
                        quad_type=sweeper_params['quad_type'],
                        QI=sweeper_params['QI'],
                        hook_class=hook_class,
                        use_adaptivity=use_A,
                        use_switch_estimator=use_SE,
                        problem_params=problem_params,
                        restol=restol,
                        maxiter=handling_params['maxiter'],
                        max_restarts=handling_params['max_restarts'],
                        tol_event=handling_params['tol_event'],
                        alpha=handling_params['alpha'],
                    )

                    stats, t_switch_exact = controllerRun(
                        description=description,
                        controller_params=controller_params,
                        controller=controller,
                        t0=interval[0],
                        Tend=interval[-1],
                        exact_event_time_avail=handling_params['exact_event_time_avail'],
                    )

                    result = getDataDict(
                        stats, prob_cls_name, use_A, use_SE, handling_params['recomputed'], t_switch_exact
                    )
                    u_num[dt][M][use_SE][use_A] = result

                    plotSolution(result, prob_cls_name, use_A, use_SE)

                    # Hardcoded reference solutions exist only for M_fix. Testing u_num[dt][M_fix] from within
                    # another M's iteration would raise a KeyError if M_fix had not been computed yet, and would
                    # otherwise re-test the same data.
                    if M == M_fix:
                        testSolution(result, prob_cls_name, dt, use_A, use_SE)

    return u_num
def getUnknownLabels(prob_cls_name):
    """
    Looks up the names of the unknowns of a problem class together with their plot labels.

    Parameters
    ----------
    prob_cls_name : str
        Name of the problem class (e.g. ``'battery'``). Unknown names raise a ``KeyError``.

    Returns
    -------
    unknowns : list of str
        Contains the names of unknowns.
    unknowns_labels : list of str
        Contains the labels of unknowns for plotting (LaTeX math strings), in the same order as ``unknowns``.
    """

    # problem class name -> (names of unknowns, matching LaTeX labels)
    table = {
        'battery': (['iL', 'vC'], [r'$i_L$', r'$v_C$']),
        'battery_implicit': (['iL', 'vC'], [r'$i_L$', r'$v_C$']),
        'battery_n_capacitors': (['iL', 'vC1', 'vC2'], [r'$i_L$', r'$v_{C_1}$', r'$v_{C_2}$']),
        'DiscontinuousTestODE': (['u'], [r'$u$']),
        'piline': (['vC1', 'vC2', 'iLp'], [r'$v_{C_1}$', r'$v_{C_2}$', r'$i_{L_\pi}$']),
        'buck_converter': (['vC1', 'vC2', 'iLp'], [r'$v_{C_1}$', r'$v_{C_2}$', r'$i_{L_\pi}$']),
    }

    names, labels = table[prob_cls_name]
    return names, labels
def plotStylingStuff():  # pragma: no cover
    """
    Returns the colors used to distinguish simulation runs in plots.

    Returns
    -------
    colors : dict
        Nested dictionary ``colors[outer][inner]`` of matplotlib color names for the four combinations of two
        boolean flags.
    """

    colors = {
        False: {
            False: 'dodgerblue',
            True: 'navy',
        },
        True: {
            # 'linegreen' is not a matplotlib color name; 'limegreen' is the intended named color
            False: 'limegreen',
            True: 'darkgreen',
        },
    }

    return colors
def plotSolution(u_num, prob_cls_name, use_adaptivity, use_detection):  # pragma: no cover
    r"""
    Plots the numerical solution for one simulation run and saves it to
    ``data/<prob_cls_name>_model_solution.png``.

    Parameters
    ----------
    u_num : dict
        Contains numerical solution with corresponding times ``'t'``, the names ``'unknowns'`` and plot labels
        ``'unknowns_labels'`` of the unknowns (one entry per unknown name), the found events ``'t_switches'``
        (read if ``use_detection``), and the step sizes ``'dt'`` as 2D array of (time, step size) rows
        (read if ``use_adaptivity``).
    prob_cls_name : str
        Name of the problem class to be plotted.
    use_adaptivity : bool
        Indicates whether adaptivity is used in the simulation or not. If so, the step sizes are plotted on a
        secondary y-axis.
    use_detection : bool
        Indicates whether event detection is used in the simulation or not. If so, each found event is marked
        with a vertical dashed line.
    """

    fig, ax = plt_helper.plt.subplots(1, 1, figsize=(7.5, 5))

    for unknown, unknown_label in zip(u_num['unknowns'], u_num['unknowns_labels']):
        ax.plot(u_num['t'], u_num[unknown], label=unknown_label)

    if use_detection:
        for i, t_switch in enumerate(u_num['t_switches'], start=1):
            ax.axvline(x=t_switch, linestyle='--', linewidth=0.8, color='r', label='Event {}'.format(i))

    if use_adaptivity:
        dt_ax = ax.twinx()
        dt = u_num['dt']
        dt_ax.plot(dt[:, 0], dt[:, 1], linestyle='-', linewidth=0.8, color='k', label=r'$\Delta t$')
        dt_ax.set_ylabel(r'$\Delta t$', fontsize=16)
        dt_ax.legend(frameon=False, fontsize=12, loc='center right')

    ax.legend(frameon=False, fontsize=12, loc='upper right')
    ax.set_xlabel(r'$t$', fontsize=16)
    ax.set_ylabel(r'$u(t)$', fontsize=16)

    # NOTE(review): the file name depends only on the problem class, so runs with different dt/adaptivity/
    # detection settings overwrite each other's figure.
    fig.savefig(f'data/{prob_cls_name}_model_solution.png', dpi=300, bbox_inches='tight')
    plt_helper.plt.close(fig)
def _asEventList(t_switch_exact):
    """Returns the exact event time(s) as a list: sequences are copied into a list, scalars/None are wrapped."""
    is_sequence = isinstance(t_switch_exact, (list, tuple)) or (
        isinstance(t_switch_exact, np.ndarray) and t_switch_exact.ndim > 0
    )
    return list(t_switch_exact) if is_sequence else [t_switch_exact]


def getDataDict(stats, prob_cls_name, use_adaptivity, use_detection, recomputed, t_switch_exact):
    r"""
    Extracts statistics and stores them in a dictionary. In this routine, different data are extracted from
    ``stats``, such as

    - each component of solution ``'u'`` and corresponding time domain ``'t'``,
    - the unknowns of the problem ``'unknowns'``,
    - the unknowns of the problem as labels for plotting ``'unknowns_labels'``,
    - global error ``'e_global'`` after each step,
    - events found by event detection ``'t_switches'``,
    - exact event time ``'t_switch_exact'``,
    - event error ``'e_event'``,
    - state function ``'state_function'``,
    - embedded error estimate computing when using adaptivity ``'e_em'``,
    - (adjusted) step sizes ``'dt'``,
    - sum over restarts ``'sum_restarts'``,
    - and the sum over all iterations ``'sum_niters'``.

    Note
    ----
    In order to use these data, the corresponding hook classes have to be defined before the simulation. Otherwise,
    no values can be obtained.

    The global error only makes sense when an exact solution for the problem is available. Although ``'e_global'`` is
    stored for each problem class, the global error is only taken into account for ``DiscontinuousTestODE`` when testing
    the solution.

    Also the event error ``'e_event'`` can only be computed if an exact event time is available. Since the function
    ``controllerRun`` returns ``t_switch_exact=None`` when no exact event time is available, the event error is only
    computed for events whose exact time is not ``None``.

    Parameters
    ----------
    stats : dict
        Raw statistics of one simulation run.
    prob_cls_name : str
        Name of the problem class.
    use_adaptivity : bool
        Indicates whether adaptivity is used in the simulation or not.
    use_detection : bool
        Indicates whether event detection is used or not.
    recomputed : bool
        Indicates if values after successful steps are used or not.
    t_switch_exact : float, list of float or None
        Exact event time(s) of the problem. A single value (or ``None``) is treated as a one-element list; a
        list/tuple/array is used as given, with one entry per event.

    Returns
    -------
    res : dict
        Dictionary with extracted data separated with reasonable keys.
    """

    res = {}
    unknowns, unknowns_labels = getUnknownLabels(prob_cls_name)

    # numerical solution
    u_val = get_sorted(stats, type='u', sortby='time', recomputed=recomputed)
    res['t'] = np.array([item[0] for item in u_val])
    for i, label in enumerate(unknowns):
        res[label] = np.array([item[1][i] for item in u_val])

    res['unknowns'] = unknowns
    res['unknowns_labels'] = unknowns_labels

    # global error
    res['e_global'] = np.array(get_sorted(stats, type='e_global_post_step', sortby='time', recomputed=recomputed))

    # event time(s) found by event detection
    if use_detection:
        switches = get_sorted(stats, type='switch', sortby='time', recomputed=recomputed)
        assert len(switches) >= 1, 'No events found!'
        t_switches = [t[1] for t in switches]
        res['t_switches'] = t_switches

        # a list of exact event times must not be wrapped once more, otherwise abs(float - list) fails below
        t_switch_exact = _asEventList(t_switch_exact)
        res['t_switch_exact'] = t_switch_exact

        if not all(t is None for t in t_switch_exact):
            # events without an exact time (None) cannot contribute an error and are skipped
            res['e_event'] = [
                abs(num_item - ex_item)
                for num_item, ex_item in zip(t_switches, t_switch_exact)
                if ex_item is not None
            ]

    h_val = get_sorted(stats, type='state_function', sortby='time', recomputed=recomputed)
    res['state_function'] = np.array([np.abs(val[1]) for val in h_val])

    # embedded error and adapted step sizes
    if use_adaptivity:
        res['e_em'] = np.array(get_sorted(stats, type='error_embedded_estimate', sortby='time', recomputed=recomputed))
        res['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed))

    # sum over restarts
    if use_adaptivity or use_detection:
        res['sum_restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1])

    # sum over all iterations
    res['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1])
    return res
if __name__ == "__main__":
    # run all battery model simulations (plots and checks against hardcoded solutions) when executed as a script
    main()