Coverage for pySDC/projects/PinTSimE/battery_model.py: 100%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1import numpy as np 

2from pathlib import Path 

3 

4from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted 

5from pySDC.implementations.problem_classes.Battery import battery, battery_implicit, battery_n_capacitors 

6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

8from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

9 

10import pySDC.helpers.plot_helper as plt_helper 

11 

12from pySDC.core.hooks import Hooks 

13from pySDC.implementations.hooks.log_solution import LogSolution 

14from pySDC.implementations.hooks.log_step_size import LogStepSize 

15from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate 

16 

17from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator 

18from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity 

19from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI 

20 

21from pySDC.projects.PinTSimE.hardcoded_solutions import testSolution 

22 

23 

class LogEventBattery(Hooks):
    """
    Hook recording the problem-dependent state function of the battery drain model after every step.

    The recorded quantity is ``uend[1:] - V_ref``, i.e. every solution component except the first one minus the
    reference values ``V_ref`` stored on the problem instance.
    """

    def post_step(self, step, level_number):
        """
        Append the current state function to the statistics under the type ``'state_function'``.

        Parameters
        ----------
        step : pySDC.core.Step
            The step that has just been finished.
        level_number : int
            Index of the level whose end point is evaluated.
        """
        super().post_step(step, level_number)

        lvl = step.levels[level_number]

        # make sure ``uend`` holds the end point of the current sweep before reading it
        lvl.sweep.compute_end_point()

        state_function_value = lvl.uend[1:] - lvl.prob.V_ref[:]

        self.add_to_stats(
            process=step.status.slot,
            time=lvl.time + lvl.dt,
            level=lvl.level_index,
            iter=0,
            sweep=lvl.status.sweep,
            type='state_function',
            value=state_function_value,
        )

46 

47 

def generateDescription(
    dt,
    problem,
    sweeper,
    num_nodes,
    quad_type,
    QI,
    hook_class,
    use_adaptivity,
    use_switch_estimator,
    problem_params,
    restol,
    maxiter,
    max_restarts=None,
    tol_event=1e-10,
    alpha=1.0,
):
    r"""
    Generate a description for the battery models for a controller run.

    Parameters
    ----------
    dt : float
        Time step for computation.
    problem : pySDC.core.Problem
        Problem class that wants to be simulated.
    sweeper : pySDC.core.Sweeper
        Sweeper class for solving the problem class numerically.
    num_nodes : int
        Number of collocation nodes.
    quad_type : str
        Type of quadrature nodes, e.g. ``'LOBATTO'`` or ``'RADAU-RIGHT'``.
    QI : str
        Type of preconditioner used in SDC, e.g. ``'IE'`` or ``'LU'``.
    hook_class : List of pySDC.core.Hooks
        Logged data for a problem, e.g., hook_class=[LogSolution] logs the solution ``'u'``
        during the simulation.
    use_adaptivity : bool
        Flag if the adaptivity wants to be used or not.
    use_switch_estimator : bool
        Flag if the switch estimator wants to be used or not.
    problem_params : dict
        Dictionary containing the problem parameters.
    restol : float
        Residual tolerance to terminate. Must be ``-1`` when ``use_adaptivity`` is set.
    maxiter : int
        Maximum number of iterations to be done.
    max_restarts : int, optional
        Maximum number of restarts per step. If ``None``, no ``BasicRestartingNonMPI`` convergence controller is added.
    tol_event : float, optional
        Tolerance for event detection to terminate.
    alpha : float, optional
        Factor that indicates how the new step size in the Switch Estimator is reduced.

    Returns
    -------
    description : dict
        Contains all information for a controller run.
    controller_params : dict
        Parameters needed for a controller run.
    controller : controller_nonMPI
        Controller with ``num_procs=1`` instantiated from ``description`` and ``controller_params``.
    """

    # with adaptivity the residual tolerance is disabled, so the caller has to request that explicitly
    if use_adaptivity:
        assert restol == -1, "Please set restol to -1 or omit it"

    # initialize level parameters
    level_params = {
        'restol': -1 if use_adaptivity else restol,
        'dt': dt,
    }

    # initialize sweeper parameters
    sweeper_params = {
        'quad_type': quad_type,
        'num_nodes': num_nodes,
        'QI': QI,
        'initial_guess': 'spread',
    }

    # initialize step parameters; no 'errtol' is set since no exact solution is known to compute the error
    step_params = {
        'maxiter': maxiter,
    }

    # initialize controller parameters
    controller_params = {
        'logger_level': 30,
        'hook_class': hook_class,
        'mssdc_jac': False,
    }

    # convergence controllers (insertion order: switch estimator, adaptivity, restarting)
    convergence_controllers = {}
    if use_switch_estimator:
        convergence_controllers[SwitchEstimator] = {
            'tol': tol_event,
            'alpha': alpha,
        }
    if use_adaptivity:
        convergence_controllers[Adaptivity] = {
            'e_tol': 1e-7,
        }
    if max_restarts is not None:
        convergence_controllers[BasicRestartingNonMPI] = {
            'max_restarts': max_restarts,
            'crash_after_max_restarts': False,
        }

    # fill description dictionary for easy step instantiation
    description = {
        'problem_class': problem,
        'problem_params': problem_params,
        'sweeper_class': sweeper,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
        'convergence_controllers': convergence_controllers,
    }

    # instantiate controller
    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    return description, controller_params, controller

174 

175 

def controllerRun(description, controller_params, controller, t0, Tend, exact_event_time_avail=False):
    """
    Executes a controller run for a problem defined in the description.

    The initial value is taken from the exact solution ``u_exact(t0)`` of the problem instance on the finest level
    of the first step of ``controller``.

    Parameters
    ----------
    description : dict
        Contains all information for a controller run. Kept for interface compatibility; the run itself only uses
        ``controller``.
    controller_params : dict
        Parameters needed for a controller run. Kept for interface compatibility; the run itself only uses
        ``controller``.
    controller : pySDC.core.Controller
        Controller that performs the simulation.
    t0 : float
        Starting time of simulation.
    Tend : float
        End time of simulation.
    exact_event_time_avail : bool, optional
        Indicates if exact event time of a problem is available.

    Returns
    -------
    stats : dict
        Raw statistics from a controller run.
    t_switch_exact : float or None
        The problem's ``t_switch_exact`` attribute if ``exact_event_time_avail`` is truthy, otherwise ``None``.
    """

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)
    t_switch_exact = P.t_switch_exact if exact_event_time_avail else None

    # call main function to get things done... (the final solution itself is not needed here)
    _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    return stats, t_switch_exact

210 

211 

def main():
    r"""
    Executes the simulation.

    First, ``runSimulation`` is called for the one-capacitor models ``battery`` (with ``imex_1st_order``) and
    ``battery_implicit`` (with ``generic_implicit``). Each is run once with explicitly set problem parameters and
    once with only ``ncapacitors`` set, so that the remaining parameters come from the problem class defaults.
    Afterwards, the two-capacitor model ``battery_n_capacitors`` is simulated. Every call iterates over all
    combinations of event detection and adaptivity.

    Note
    ----
    Hardcoded solutions for battery models in `pySDC.projects.PinTSimE.hardcoded_solutions` are only computed for
    ``dt_list=[1e-2, 1e-3]`` and ``M_fix=4``. Hence changing ``dt_list`` and ``M_fix`` to different values could
    raise an ``AssertionError``.
    """

    # defines parameters for sweeper
    M_fix = 4
    sweeper_params = {
        'num_nodes': M_fix,
        'quad_type': 'LOBATTO',
        'QI': 'IE',
    }

    # defines parameters for event detection, restol, and max. number of iterations
    handling_params = {
        'restol': -1,
        'maxiter': 8,
        'max_restarts': 50,
        'recomputed': False,
        'tol_event': 1e-10,
        'alpha': 0.96,
        'exact_event_time_avail': None,
    }

    all_params = {
        'sweeper_params': sweeper_params,
        'handling_params': handling_params,
    }

    hook_class = [LogSolution, LogEventBattery, LogEmbeddedErrorEstimate, LogStepSize]

    use_detection = [True, False]
    use_adaptivity = [True, False]

    # step sizes for which hardcoded reference solutions exist (see Note in the docstring)
    dt_list = [1e-2, 1e-3]

    for problem, sweeper in zip([battery, battery_implicit], [imex_1st_order, generic_implicit]):
        for defaults in [False, True]:
            # for hardcoded solutions problem parameter defaults should match with parameters here
            if defaults:
                params_battery_1capacitor = {
                    'ncapacitors': 1,
                }
            else:
                params_battery_1capacitor = {
                    'ncapacitors': 1,
                    'C': np.array([1.0]),
                    'alpha': 1.2,
                    'V_ref': np.array([1.0]),
                }

            all_params['problem_params'] = params_battery_1capacitor

            runSimulation(
                problem=problem,
                sweeper=sweeper,
                all_params=all_params,
                use_adaptivity=use_adaptivity,
                use_detection=use_detection,
                hook_class=hook_class,
                interval=(0.0, 0.3),
                dt_list=dt_list,
                nnodes=[M_fix],
            )

    # defines parameters for the problem class
    params_battery_2capacitors = {
        'ncapacitors': 2,
        'C': np.array([1.0, 1.0]),
        'alpha': 1.2,
        'V_ref': np.array([1.0, 1.0]),
    }

    all_params['problem_params'] = params_battery_2capacitors

    runSimulation(
        problem=battery_n_capacitors,
        sweeper=imex_1st_order,
        all_params=all_params,
        use_adaptivity=use_adaptivity,
        use_detection=use_detection,
        hook_class=hook_class,
        interval=(0.0, 0.5),
        dt_list=dt_list,
        nnodes=[M_fix],
    )

302 

303 

def runSimulation(problem, sweeper, all_params, use_adaptivity, use_detection, hook_class, interval, dt_list, nnodes):
    r"""
    Script that executes the simulation for a given problem class for given parameters defined by the user.

    For every combination of step size, number of collocation nodes, event detection flag and adaptivity flag, a
    controller run is performed, its statistics are extracted via ``getDataDict`` and plotted via ``plotSolution``.
    The result for the fixed number of nodes ``all_params['sweeper_params']['num_nodes']`` is additionally checked
    against hardcoded solutions via ``testSolution``.

    Parameters
    ----------
    problem : pySDC.core.Problem
        Problem class to be simulated.
    sweeper : pySDC.core.Sweeper
        Sweeper that is used to simulate the problem class.
    all_params : dict
        Dictionary contains the problem parameters for ``problem`` (key ``'problem_params'``), the sweeper parameters
        for ``sweeper`` (key ``'sweeper_params'``), and handling parameters (key ``'handling_params'``) needed for
        event detection, i.e., ``restol``, ``maxiter``, ``max_restarts``, ``recomputed``, ``tol_event``, ``alpha``,
        and ``exact_event_time_avail``.
    use_adaptivity : list of bool
        Indicates whether adaptivity is used in the simulation or not. Here a list is used to iterate over the
        different cases, i.e., ``use_adaptivity=[True, False]``.
    use_detection : list of bool
        Indicates whether event detection is used in the simulation or not. Here a list is used to iterate over the
        different cases, i.e., ``use_detection=[True, False]``.
    hook_class : list of pySDC.core.Hooks
        List containing the different hook classes to log data during the simulation, i.e., ``hook_class=[LogSolution]``
        logs the solution ``u``.
    interval : tuple
        Simulation interval.
    dt_list : list of float
        List containing different step sizes where the solution is computed.
    nnodes : list of int
        The solution can be computed for different number of collocation nodes.

    Returns
    -------
    u_num : dict
        Nested dictionary ``u_num[dt][M][use_SE][use_A]`` holding the data returned by ``getDataDict`` for each run.
    """

    Path("data").mkdir(parents=True, exist_ok=True)

    prob_cls_name = problem.__name__

    # these parameters do not change between runs, so extract them once
    problem_params = all_params['problem_params']
    sweeper_params = all_params['sweeper_params']
    handling_params = all_params['handling_params']

    # testing results for fixed M requires that M_fix is included in nnodes!
    M_fix = sweeper_params['num_nodes']
    assert M_fix in nnodes, f"For fixed number of collocation nodes {M_fix} no solution will be computed!"

    u_num = {}

    for dt in dt_list:
        u_num[dt] = {}

        for M in nnodes:
            u_num[dt][M] = {}

            for use_SE in use_detection:
                u_num[dt][M][use_SE] = {}

                for use_A in use_adaptivity:
                    # adaptivity requires the residual tolerance to be switched off
                    restol = -1 if use_A else handling_params['restol']

                    description, controller_params, controller = generateDescription(
                        dt=dt,
                        problem=problem,
                        sweeper=sweeper,
                        num_nodes=M,
                        quad_type=sweeper_params['quad_type'],
                        QI=sweeper_params['QI'],
                        hook_class=hook_class,
                        use_adaptivity=use_A,
                        use_switch_estimator=use_SE,
                        problem_params=problem_params,
                        restol=restol,
                        maxiter=handling_params['maxiter'],
                        max_restarts=handling_params['max_restarts'],
                        tol_event=handling_params['tol_event'],
                        alpha=handling_params['alpha'],
                    )

                    stats, t_switch_exact = controllerRun(
                        description=description,
                        controller_params=controller_params,
                        controller=controller,
                        t0=interval[0],
                        Tend=interval[-1],
                        exact_event_time_avail=handling_params['exact_event_time_avail'],
                    )

                    u_num[dt][M][use_SE][use_A] = getDataDict(
                        stats, prob_cls_name, use_A, use_SE, handling_params['recomputed'], t_switch_exact
                    )

                    # NOTE(review): plotSolution writes to a file name depending only on prob_cls_name, so each
                    # run overwrites the previous figure and only the last combination's plot is kept.
                    plotSolution(u_num[dt][M][use_SE][use_A], prob_cls_name, use_A, use_SE)

                    # only the run with M_fix is compared to the hardcoded solutions; testing inside other M
                    # iterations would access a not-yet-computed entry or re-test an already tested one
                    if M == M_fix:
                        testSolution(u_num[dt][M][use_SE][use_A], prob_cls_name, dt, use_A, use_SE)

    return u_num

401 

402 

def getUnknownLabels(prob_cls_name):
    """
    Look up the unknowns of a problem class together with the matching plot labels.

    Parameters
    ----------
    prob_cls_name : str
        Name of the problem class.

    Returns
    -------
    unknowns : list of str
        Names of the unknowns, in the order of the solution components.
    unknowns_labels : list of str
        LaTeX labels of the unknowns for plotting, in the same order.

    Raises
    ------
    KeyError
        If ``prob_cls_name`` is not one of the known problem classes.
    """

    # each problem class maps to its (unknown, label) pairs
    name_label_pairs = {
        'battery': [('iL', r'$i_L$'), ('vC', r'$v_C$')],
        'battery_implicit': [('iL', r'$i_L$'), ('vC', r'$v_C$')],
        'battery_n_capacitors': [('iL', r'$i_L$'), ('vC1', r'$v_{C_1}$'), ('vC2', r'$v_{C_2}$')],
        'DiscontinuousTestODE': [('u', r'$u$')],
        'piline': [('vC1', r'$v_{C_1}$'), ('vC2', r'$v_{C_2}$'), ('iLp', r'$i_{L_\pi}$')],
        'buck_converter': [('vC1', r'$v_{C_1}$'), ('vC2', r'$v_{C_2}$'), ('iLp', r'$i_{L_\pi}$')],
    }

    selected = name_label_pairs[prob_cls_name]
    names = [name for name, _ in selected]
    labels = [label for _, label in selected]
    return names, labels

439 

440 

def plotStylingStuff():  # pragma: no cover
    """
    Returns plot stuff such as colors for making plots more pretty.

    Returns
    -------
    colors : dict
        Nested dictionary ``colors[flag_1][flag_2]`` mapping two booleans to a matplotlib named color.
        NOTE(review): presumably the flags are (use_detection, use_adaptivity); no caller is visible here to confirm.
    """

    colors = {
        False: {
            False: 'dodgerblue',
            True: 'navy',
        },
        True: {
            # 'linegreen' is not a matplotlib color name and would fail when used; 'limegreen' is the valid name
            False: 'limegreen',
            True: 'darkgreen',
        },
    }

    return colors

458 

459 

def plotSolution(u_num, prob_cls_name, use_adaptivity, use_detection):  # pragma: no cover
    r"""
    Plots the numerical solution for one simulation run.

    Each unknown listed in ``u_num['unknowns']`` is plotted over ``u_num['t']``. With event detection, every detected
    event time is marked by a dashed red vertical line. With adaptivity, the step sizes are plotted on a secondary
    y-axis. The figure is written to ``data/{prob_cls_name}_model_solution.png``.

    Parameters
    ----------
    u_num : dict
        Contains numerical solution with corresponding times for different problem_classes, and
        labels for different unknowns of the problem.
    prob_cls_name : str
        Name of the problem class to be plotted.
    use_adaptivity : bool
        Indicates whether adaptivity is used in the simulation or not.
    use_detection : bool
        Indicates whether event detection is used in the simulation or not.

    Note
    ----
    The output file name only depends on ``prob_cls_name``, so repeated calls for the same problem class overwrite
    the previously written figure.
    """

    fig, ax = plt_helper.plt.subplots(1, 1, figsize=(7.5, 5))

    for unknown, unknown_label in zip(u_num['unknowns'], u_num['unknowns_labels']):
        ax.plot(u_num['t'], u_num[unknown], label=unknown_label)

    if use_detection:
        for event_number, t_switch in enumerate(u_num['t_switches'], start=1):
            ax.axvline(x=t_switch, linestyle='--', linewidth=0.8, color='r', label=f'Event {event_number}')

    if use_adaptivity:
        dt_ax = ax.twinx()
        # rows of u_num['dt'] are (time, step size) pairs
        dt = u_num['dt']
        dt_ax.plot(dt[:, 0], dt[:, 1], linestyle='-', linewidth=0.8, color='k', label=r'$\Delta t$')
        dt_ax.set_ylabel(r'$\Delta t$', fontsize=16)
        dt_ax.legend(frameon=False, fontsize=12, loc='center right')

    ax.legend(frameon=False, fontsize=12, loc='upper right')
    ax.set_xlabel(r'$t$', fontsize=16)
    ax.set_ylabel(r'$u(t)$', fontsize=16)

    fig.savefig(f'data/{prob_cls_name}_model_solution.png', dpi=300, bbox_inches='tight')
    plt_helper.plt.close(fig)

500 

501 

def getDataDict(stats, prob_cls_name, use_adaptivity, use_detection, recomputed, t_switch_exact):
    r"""
    Extracts statistics and stores them in a dictionary. In this routine, different data are extracted from
    ``stats``, such as

    - each component of solution ``'u'`` and corresponding time domain ``'t'``,
    - the unknowns of the problem ``'unknowns'``,
    - the unknowns of the problem as labels for plotting ``'unknowns_labels'``,
    - global error ``'e_global'`` after each step,
    - events found by event detection ``'t_switches'``,
    - exact event time ``'t_switch_exact'``,
    - event error ``'e_event'``,
    - state function ``'state_function'``,
    - embedded error estimate computed when using adaptivity ``'e_em'``,
    - (adjusted) step sizes ``'dt'``,
    - sum over restarts ``'sum_restarts'``,
    - and the sum over all iterations ``'sum_niters'``.

    Note
    ----
    In order to use these data, the corresponding hook classes have to be defined before the simulation. Otherwise,
    no values can be obtained.

    The global error only makes sense when an exact solution for the problem is available. Although ``'e_global'``
    is stored for each problem class, only for ``DiscontinuousTestODE`` the global error is taken into account when
    testing the solution.

    Also the event error ``'e_event'`` can only be computed if an exact event time is available. Since the function
    ``controllerRun`` returns ``t_switch_exact=None`` when no exact event time is available, the event error is only
    computed if the list of exact event times (in case of more than one event) does not consist solely of ``None``.

    Parameters
    ----------
    stats : dict
        Raw statistics of one simulation run.
    prob_cls_name : str
        Name of the problem class.
    use_adaptivity : bool
        Indicates whether adaptivity is used in the simulation or not.
    use_detection : bool
        Indicates whether event detection is used or not.
    recomputed : bool
        Indicates if values after successful steps are used or not.
    t_switch_exact : float, list/tuple of float, or None
        Exact event time(s) of the problem. A single value (or ``None``) is wrapped into a list; a list or tuple of
        event times is used as given.

    Returns
    -------
    res : dict
        Dictionary with extracted data separated with reasonable keys.
    """

    res = {}
    unknowns, unknowns_labels = getUnknownLabels(prob_cls_name)

    # numerical solution; each entry of u_val is a (time, value) pair
    u_val = get_sorted(stats, type='u', sortby='time', recomputed=recomputed)
    res['t'] = np.array([item[0] for item in u_val])
    for i, label in enumerate(unknowns):
        res[label] = np.array([item[1][i] for item in u_val])

    res['unknowns'] = unknowns
    res['unknowns_labels'] = unknowns_labels

    # global error
    res['e_global'] = np.array(get_sorted(stats, type='e_global_post_step', sortby='time', recomputed=recomputed))

    # event time(s) found by event detection
    if use_detection:
        switches = get_sorted(stats, type='switch', sortby='time', recomputed=recomputed)
        assert len(switches) >= 1, 'No events found!'
        res['t_switches'] = [t[1] for t in switches]

        # wrapping an already given list of event times would nest it and break the error computation below
        if isinstance(t_switch_exact, (list, tuple)):
            t_switch_exact = list(t_switch_exact)
        else:
            t_switch_exact = [t_switch_exact]
        res['t_switch_exact'] = t_switch_exact

        if not all(t is None for t in t_switch_exact):
            res['e_event'] = [abs(t_num - t_ex) for t_num, t_ex in zip(res['t_switches'], t_switch_exact)]

    h_val = get_sorted(stats, type='state_function', sortby='time', recomputed=recomputed)
    res['state_function'] = np.array([np.abs(val[1]) for val in h_val])

    # embedded error and adapted step sizes
    if use_adaptivity:
        res['e_em'] = np.array(get_sorted(stats, type='error_embedded_estimate', sortby='time', recomputed=recomputed))
        res['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed))

    # sum over restarts; an empty stats list would make np.array(...)[:, 1] raise an IndexError
    if use_adaptivity or use_detection:
        restarts = get_sorted(stats, type='restart', recomputed=None, sortby='time')
        res['sum_restarts'] = np.sum(np.array(restarts)[:, 1]) if restarts else 0

    # sum over all iterations
    niters = get_sorted(stats, type='niter', recomputed=None, sortby='time')
    res['sum_niters'] = np.sum(np.array(niters)[:, 1]) if niters else 0
    return res

600 

601 

if __name__ == "__main__":
    # run the full simulation study only when this file is executed as a script, not when it is imported
    main()