Coverage for pySDC / projects / PinTSimE / estimation_check.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 07:06 +0000

1import numpy as np 

2from pathlib import Path 

3 

4from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted 

5from pySDC.implementations.problem_classes.Battery import battery_n_capacitors, battery, battery_implicit 

6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

8 

9from pySDC.projects.PinTSimE.battery_model import runSimulation, plotStylingStuff 

10 

11import pySDC.helpers.plot_helper as plt_helper 

12 

13from pySDC.projects.PinTSimE.battery_model import LogEventBattery 

14from pySDC.implementations.hooks.log_solution import LogSolution 

15from pySDC.implementations.hooks.log_step_size import LogStepSize 

16from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate 

17 

18 

def run_estimation_check():
    r"""
    Generate plots that visualise the results of applying the Switch Estimator and Adaptivity to the battery models.

    Each battery model (``battery``, ``battery_implicit`` and ``battery_n_capacitors``) is simulated via
    ``runSimulation`` with event detection and adaptivity each switched on and off. For every model the accuracy
    check and the state function over time are then plotted into the ``data`` directory.

    Note
    ----
    Hardcoded solutions for battery models in `pySDC.projects.PinTSimE.hardcoded_solutions` are only computed for
    ``dt_list=[1e-2, 1e-3]`` and ``M_fix=4``. Hence, changing ``dt_list`` and ``M_fix`` to different values could
    raise an ``AssertionError``.
    """

    Path("data").mkdir(parents=True, exist_ok=True)

    # ---- sweeper parameters (shared by all problem classes) ----
    M_fix = 4
    sweeper_params = {'num_nodes': M_fix, 'quad_type': 'LOBATTO', 'QI': 'IE'}

    # ---- event detection, adaptivity and iteration settings (shared by all problem classes) ----
    handling_params = {
        'restol': -1,
        'maxiter': 8,
        'max_restarts': 50,
        'recomputed': False,
        'tol_event': 1e-10,
        'alpha': 0.96,
        'exact_event_time_avail': None,
    }

    # ---- problem parameters: a single capacitor and two capacitors ----
    params_battery_1capacitor = {'ncapacitors': 1, 'C': np.array([1.0]), 'alpha': 1.2, 'V_ref': np.array([1.0])}
    params_battery_2capacitors = {
        'ncapacitors': 2,
        'C': np.array([1.0, 1.0]),
        'alpha': 1.2,
        'V_ref': np.array([1.0, 1.0]),
    }

    # ---- (problem class, sweeper class, problem parameters, simulation interval) for every model ----
    setups = [
        (battery, imex_1st_order, params_battery_1capacitor, (0.0, 0.3)),
        (battery_implicit, generic_implicit, params_battery_1capacitor, (0.0, 0.3)),
        (battery_n_capacitors, imex_1st_order, params_battery_2capacitors, (0.0, 0.5)),
    ]

    hook_class = [LogSolution, LogEventBattery, LogEmbeddedErrorEstimate, LogStepSize]
    use_detection = [True, False]
    use_adaptivity = [True, False]

    for problem, sweeper, problem_params, sim_interval in setups:
        prob_cls_name = problem.__name__

        u_num = runSimulation(
            problem=problem,
            sweeper=sweeper,
            all_params={
                'sweeper_params': sweeper_params,
                'handling_params': handling_params,
                'problem_params': problem_params,
            },
            use_adaptivity=use_adaptivity,
            use_detection=use_detection,
            hook_class=hook_class,
            interval=sim_interval,
            dt_list=[1e-2, 1e-3],
            nnodes=[M_fix],
        )

        plotAccuracyCheck(u_num, prob_cls_name, M_fix)
        # NOTE: plotStateFunctionAroundEvent is deliberately not called here; its docstring carries a TODO
        # stating that it does not work as expected yet.
        plotStateFunctionOverTime(u_num, prob_cls_name, M_fix)

120 

121 

def plotAccuracyCheck(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Routine to check the accuracy for different step sizes when adaptivity is used.

    One figure is created for each step size ``dt`` stored in ``u_num``. For every run with adaptivity (with and
    without event detection), the figure shows:

    - the adapted step sizes over time (solid lines, left axis);
    - the embedded error estimate over time (dash-dotted lines, right axis, logarithmic).

    For runs with event detection, every detected switching time is marked by a vertical dashed line. The figure
    is saved to ``data/detection_and_adaptivity_{prob_cls_name}_dt={dt}_M={M_fix}.png``.

    Parameters
    ----------
    u_num : dict
        Contains all the data. The dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, and for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the corresponding results are stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    colors = plotStylingStuff()
    use_A = True  # only runs with adaptivity are shown in this plot

    for dt in u_num.keys():
        fig, ax = plt_helper.plt.subplots(1, 1, figsize=(7.5, 5), sharex='col', sharey='row')
        e_ax = ax.twinx()  # second y-axis for the embedded error estimate

        for use_SE in u_num[dt][M_fix].keys():
            results = u_num[dt][M_fix][use_SE][use_A]
            dt_val = results['dt']
            e_em_val = results['e_em']

            if use_SE:
                for event_no, t_switch in enumerate(results['t_switches'], start=1):
                    ax.axvline(x=t_switch, linestyle='--', color='tomato', label='Event {}'.format(event_no))

            color = colors[use_SE][use_A]
            ax.plot(dt_val[:, 0], dt_val[:, 1], color=color, label=r'SE={}, A={}'.format(use_SE, use_A))
            e_ax.plot(e_em_val[:, 0], e_em_val[:, 1], linestyle='dashdot', color=color)

        # dummy lines at (0, 0) that only serve as legend entries explaining the two line styles
        ax.plot(0, 0, color='black', linestyle='solid', label=r'$\Delta t_\mathrm{adapt}$')
        ax.plot(0, 0, color='black', linestyle='dashdot', label=r'$e_{em}$')

        e_ax.set_yscale('log', base=10)
        e_ax.set_ylabel(r'Embedded error estimate $e_{em}$', fontsize=16)
        e_ax.set_ylim(1e-16, 1e-7)
        e_ax.tick_params(labelsize=16)
        e_ax.minorticks_off()

        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.set_ylim(1e-9, 1e0)
        ax.set_yscale('log', base=10)
        ax.set_xlabel(r'Time $t$', fontsize=16)
        ax.set_ylabel(r'Adapted step sizes $\Delta t_\mathrm{adapt}$', fontsize=16)
        ax.grid(visible=True)
        ax.minorticks_off()
        ax.legend(frameon=True, fontsize=12, loc='center left')

        out_file = 'data/detection_and_adaptivity_{}_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix)
        fig.savefig(out_file, dpi=300, bbox_inches='tight')
        plt_helper.plt.close(fig)

183 

184 

def _distinct_switches(t_switches, tol=1e-10):  # pragma: no cover
    r"""
    Remove switching times that lie within ``tol`` of their predecessor.

    Consecutive entries closer than ``tol`` are treated as the same event and only the first one is kept
    (presumably the switch estimator can report one event more than once -- confirm).

    Parameters
    ----------
    t_switches : array_like
        One-dimensional sequence of switching times in the order they were stored.
    tol : float, optional
        Consecutive entries have to differ by more than ``tol`` to count as different events.

    Returns
    -------
    np.ndarray
        Switching times without consecutive near-duplicates (empty if the input is empty).
    """
    t_switches = np.asarray(t_switches)
    if t_switches.size == 0:
        # the mask below would have length 1 for an empty input and fail to index it
        return t_switches
    keep = np.append([True], np.abs(np.diff(t_switches)) > tol)
    return t_switches[keep]


def _state_function_component(h_val, n, i):  # pragma: no cover
    r"""
    Return component ``i`` of the stored state function values ``h_val`` as a flat one-dimensional array.

    For ``n == 1`` the whole array is taken; otherwise column ``i`` is selected.
    """
    h = h_val[:] if n == 1 else h_val[:, i]
    return h.reshape((h.shape[0],))


def _plot_before_and_after_event(axis, dt_list, t_list, h_list, t_switch):  # pragma: no cover
    r"""
    For every step size, plot the state function at the last time point before the switch and at the first
    time point after it.

    ``t_list``, ``h_list`` and ``t_switch`` hold one entry per step size in ``dt_list``. A value is collected for
    every pair of consecutive time points that strictly encloses the switching time.
    """
    before, after = [], []
    for t_item, h_item, t_switch_item in zip(t_list, h_list, t_switch, strict=True):
        for m in range(1, len(t_item)):
            if t_item[m - 1] < t_switch_item < t_item[m]:
                before.append(h_item[m - 1])
                after.append(h_item[m])

    axis.plot(dt_list, before, color='firebrick', marker='o', linestyle='solid', alpha=0.4, label='Before event')
    axis.plot(dt_list, after, color='deepskyblue', marker='*', linestyle='solid', alpha=0.4, label='After event')


def plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Routine that plots the state function at the time before the event, exactly at the event, and after the event.

    Note that this routine only makes sense for a state function that remains constant after the event.

    For each state function component ``i`` one figure with three panels is created:

    - detection only;
    - adaptivity only;
    - adaptivity and detection.

    The event location of component ``i`` is taken to be the ``i``-th distinct switching time. Runs without
    detection have no switching times of their own, so the switching times of the detection-only run
    (``use_SE=True, use_A=False``) are used for them. For runs with detection, the values before/after the event
    are taken from the run without any handling (``use_SE=False, use_A=False``).

    TODO: Function still does not work as expected. Every time the switch estimator is adapted, the tolerances
    do not fit anymore!

    Parameters
    ----------
    u_num : dict
        Contains all the data. The dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, and for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the corresponding results are stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    title_cases = {
        0: 'Using detection',
        1: 'Using adaptivity',
        2: 'Using adaptivity and detection',
    }
    # maximal distance between a time point and the switching time for the point to count as "at the event"
    tol_at_event = 2.7961188919789493e-11

    dt_list = list(u_num.keys())
    use_detection = list(u_num[dt_list[0]][M_fix].keys())
    use_adaptivity = list(u_num[dt_list[0]][M_fix][use_detection[0]].keys())
    h0 = u_num[dt_list[0]][M_fix][use_detection[0]][use_adaptivity[0]]['state_function']
    n = h0[0].shape[0]

    for i in range(n):
        fig, ax = plt_helper.plt.subplots(1, 3, figsize=(12, 4), sharex='col', sharey='row', squeeze=False)

        # the run without any handling does not depend on use_SE/use_A, hence it is extracted once per component
        t_no_handling = [u_num[dt][M_fix][False][False]['t'] for dt in dt_list]
        h_no_handling = [
            _state_function_component(u_num[dt][M_fix][False][False]['state_function'], n, i) for dt in dt_list
        ]

        for use_SE in use_detection:
            for use_A in use_adaptivity:
                if not use_SE and not use_A:
                    continue  # the run without handling only serves as reference, it gets no panel of its own

                ind = 0 if not use_A else (1 if not use_SE else 2)
                ax[0, ind].set_title(r'{} for $n={}$'.format(title_cases[ind], i + 1))

                t = [u_num[dt][M_fix][use_SE][use_A]['t'] for dt in dt_list]
                h = [
                    _state_function_component(u_num[dt][M_fix][use_SE][use_A]['state_function'], n, i)
                    for dt in dt_list
                ]

                if use_SE:
                    # i-th *distinct* switching time of this run for every step size
                    t_switch = [
                        _distinct_switches(u_num[dt][M_fix][use_SE][use_A]['t_switches'])[i] for dt in dt_list
                    ]

                    ax[0, ind].plot(
                        dt_list,
                        [
                            h_item[m]
                            for (t_item, h_item, t_switch_item) in zip(t, h, t_switch, strict=True)
                            for m in range(len(t_item))
                            if abs(t_item[m] - t_switch_item) <= tol_at_event
                        ],
                        color='limegreen',
                        marker='s',
                        linestyle='solid',
                        alpha=0.4,
                        label='At event',
                    )

                    # compare with the run without handling around the detected event
                    _plot_before_and_after_event(ax[0, ind], dt_list, t_no_handling, h_no_handling, t_switch)

                else:
                    # no detection in this run: locate the event via the detection-only run instead of relying on
                    # a value left over from an earlier loop iteration
                    t_switch = [
                        _distinct_switches(u_num[dt][M_fix][True][False]['t_switches'])[i] for dt in dt_list
                    ]
                    _plot_before_and_after_event(ax[0, ind], dt_list, t, h, t_switch)

                ax[0, ind].tick_params(axis='both', which='major', labelsize=16)
                ax[0, ind].set_xticks(dt_list)
                ax[0, ind].set_xticklabels(dt_list)
                ax[0, ind].set_ylim(1e-15, 1e1)
                ax[0, ind].set_yscale('log', base=10)
                ax[0, ind].set_xlabel(r'Step size $\Delta t$', fontsize=16)
                ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)
                ax[0, ind].grid(visible=True)
                ax[0, ind].minorticks_off()
                ax[0, ind].legend(frameon=True, fontsize=12, loc='lower left')

        fig.savefig(
            'data/{}_comparison_event{}_M={}.png'.format(prob_cls_name, i + 1, M_fix), dpi=300, bbox_inches='tight'
        )
        plt_helper.plt.close(fig)

345 

346 

def plotStateFunctionOverTime(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Routine that plots the state function over time.

    One figure is created for each step size ``dt``, with one panel per state function component. Each panel
    contains one curve per combination of event detection and adaptivity: dash-dotted for runs with adaptivity,
    dotted for runs without. The figure is saved to
    ``data/{prob_cls_name}_state_function_over_time_dt={dt}_M={M_fix}.png``.

    Parameters
    ----------
    u_num : dict
        Contains all the data. The dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, and for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the corresponding results are stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Indicates the name of the problem class to be considered.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    dt_list = u_num.keys()
    first_dt = list(dt_list)[0]
    use_detection = u_num[first_dt][M_fix].keys()
    use_adaptivity = u_num[first_dt][M_fix][list(use_detection)[0]].keys()
    first_run = u_num[first_dt][M_fix][list(use_detection)[0]][list(use_adaptivity)[0]]

    # number of state function components = number of panels in each figure
    n = first_run['state_function'][0].shape[0]
    figsize = (7.5, 5) if n == 1 else (12, 5)

    for dt in dt_list:
        fig, ax = plt_helper.plt.subplots(1, n, figsize=figsize, sharex='col', sharey='row', squeeze=False)

        for use_SE in use_detection:
            for use_A in use_adaptivity:
                run = u_num[dt][M_fix][use_SE][use_A]
                t = run['t']
                h_val = run['state_function']

                linestyle = 'dotted' if not use_A else 'dashdot'
                curve_label = 'Detection: {}, Adaptivity: {}'.format(use_SE, use_A)

                for component in range(n):
                    axis = ax[0, component]
                    h = h_val[:] if n == 1 else h_val[:, component]

                    axis.set_title(r'$n={}$'.format(component + 1))
                    axis.plot(t, h, linestyle=linestyle, label=curve_label)

                    axis.tick_params(axis='both', which='major', labelsize=16)
                    axis.set_ylim(1e-15, 1e0)
                    axis.set_yscale('log', base=10)
                    axis.set_xlabel(r'Time $t$', fontsize=16)
                    ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)
                    axis.grid(visible=True)
                    axis.minorticks_off()
                    axis.legend(frameon=True, fontsize=12, loc='lower left')

        out_file = 'data/{}_state_function_over_time_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix)
        fig.savefig(out_file, dpi=300, bbox_inches='tight')
        plt_helper.plt.close(fig)

401 

402 

if __name__ == '__main__':
    # only executed when this module is run as a script, not when it is imported
    run_estimation_check()