Coverage for pySDC/projects/PinTSimE/estimation_check.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1import numpy as np 

2from pathlib import Path 

3 

4from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted 

5from pySDC.implementations.problem_classes.Battery import battery_n_capacitors, battery, battery_implicit 

6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

8 

9from pySDC.projects.PinTSimE.battery_model import runSimulation, plotStylingStuff 

10 

11import pySDC.helpers.plot_helper as plt_helper 

12 

13from pySDC.projects.PinTSimE.battery_model import LogEventBattery 

14from pySDC.implementations.hooks.log_solution import LogSolution 

15from pySDC.implementations.hooks.log_step_size import LogStepSize 

16from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate 

17 

18 

def run_estimation_check():
    r"""
    Simulate the battery models for every combination of switch estimation and adaptivity and plot the results.

    For each problem class (``battery``, ``battery_implicit`` and ``battery_n_capacitors``) a simulation is run via
    ``runSimulation`` with step sizes ``1e-2`` and ``1e-3`` and ``M_fix`` collocation nodes. The adapted step sizes
    together with the embedded error estimate, as well as the state function over time, are then plotted; the
    figures are written to the ``data`` directory, which is created if it does not exist yet.

    Note
    ----
    Hardcoded solutions for battery models in `pySDC.projects.PinTSimE.hardcoded_solutions` are only computed for
    ``dt_list=[1e-2, 1e-3]`` and ``M_fix=4``. Hence, changing ``dt_list`` or ``M_fix`` to other values may raise an
    ``AssertionError``.
    """

    Path("data").mkdir(parents=True, exist_ok=True)

    # ---- sweeper parameters (shared by all problem classes) ----
    M_fix = 4
    sweeper_params = {
        'num_nodes': M_fix,
        'quad_type': 'LOBATTO',
        'QI': 'IE',
    }

    # ---- event detection settings and iteration limits (shared by all problem classes) ----
    handling_params = {
        'restol': -1,
        'maxiter': 8,
        'max_restarts': 50,
        'recomputed': False,
        'tol_event': 1e-10,
        'alpha': 0.96,
        'exact_event_time_avail': None,
    }

    # ---- problem parameters for the one- and the two-capacitor set-up ----
    params_battery_1capacitor = {
        'ncapacitors': 1,
        'C': np.array([1.0]),
        'alpha': 1.2,
        'V_ref': np.array([1.0]),
    }
    params_battery_2capacitors = {
        'ncapacitors': 2,
        'C': np.array([1.0, 1.0]),
        'alpha': 1.2,
        'V_ref': np.array([1.0, 1.0]),
    }

    # ---- collect all parameters per problem class name ----
    problem_params_by_name = {
        'battery': params_battery_1capacitor,
        'battery_implicit': params_battery_1capacitor,
        'battery_n_capacitors': params_battery_2capacitors,
    }
    all_params = {
        name: {
            'sweeper_params': sweeper_params,
            'handling_params': handling_params,
            'problem_params': problem_params,
        }
        for name, problem_params in problem_params_by_name.items()
    }

    # ---- simulation domain per problem class name ----
    interval = {
        'battery': (0.0, 0.3),
        'battery_implicit': (0.0, 0.3),
        'battery_n_capacitors': (0.0, 0.5),
    }

    hook_class = [LogSolution, LogEventBattery, LogEmbeddedErrorEstimate, LogStepSize]

    use_detection = [True, False]
    use_adaptivity = [True, False]

    # each problem class paired with the sweeper it is solved with
    problem_and_sweeper = [
        (battery, imex_1st_order),
        (battery_implicit, generic_implicit),
        (battery_n_capacitors, imex_1st_order),
    ]

    for problem, sweeper in problem_and_sweeper:
        prob_cls_name = problem.__name__

        u_num = runSimulation(
            problem=problem,
            sweeper=sweeper,
            all_params=all_params[prob_cls_name],
            use_adaptivity=use_adaptivity,
            use_detection=use_detection,
            hook_class=hook_class,
            interval=interval[prob_cls_name],
            dt_list=[1e-2, 1e-3],
            nnodes=[M_fix],
        )

        plotAccuracyCheck(u_num, prob_cls_name, M_fix)

        # plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix)

        plotStateFunctionOverTime(u_num, prob_cls_name, M_fix)

120 

121 

def plotAccuracyCheck(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Plot adapted step sizes and embedded error estimates over time for all runs that use adaptivity.

    One figure is saved per step size ``dt`` in ``u_num``. The adapted step sizes are drawn as solid lines on the
    left axis and the embedded error estimates as dash-dotted lines on a twin right axis, both coloured per
    detection setting. For runs with event detection, every found event time is marked by a dashed vertical line.

    Parameters
    ----------
    u_num : dict
        Contains all the data. Dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the simulation results are stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class (used in the file name of the saved figure).
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    colors = plotStylingStuff()

    # only runs using adaptivity have adapted step sizes worth comparing
    use_A = True

    for dt, results_for_dt in u_num.items():
        fig, ax = plt_helper.plt.subplots(1, 1, figsize=(7.5, 5), sharex='col', sharey='row')
        e_ax = ax.twinx()

        for use_SE, runs_by_adaptivity in results_for_dt[M_fix].items():
            run = runs_by_adaptivity[use_A]
            dt_val = run['dt']
            e_em_val = run['e_em']

            if use_SE:
                for event_number, t_event in enumerate(run['t_switches'], start=1):
                    ax.axvline(x=t_event, linestyle='--', color='tomato', label='Event {}'.format(event_number))

            line_color = colors[use_SE][use_A]
            ax.plot(dt_val[:, 0], dt_val[:, 1], color=line_color, label=r'SE={}, A={}'.format(use_SE, use_A))
            e_ax.plot(e_em_val[:, 0], e_em_val[:, 1], linestyle='dashdot', color=line_color)

        # invisible dummy lines that only provide the legend entries for the two line styles
        ax.plot(0, 0, color='black', linestyle='solid', label=r'$\Delta t_\mathrm{adapt}$')
        ax.plot(0, 0, color='black', linestyle='dashdot', label=r'$e_{em}$')

        e_ax.set_yscale('log', base=10)
        e_ax.set_ylabel(r'Embedded error estimate $e_{em}$', fontsize=16)
        e_ax.set_ylim(1e-16, 1e-7)
        e_ax.tick_params(labelsize=16)
        e_ax.minorticks_off()

        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.set_ylim(1e-9, 1e0)
        ax.set_yscale('log', base=10)
        ax.set_xlabel(r'Time $t$', fontsize=16)
        ax.set_ylabel(r'Adapted step sizes $\Delta t_\mathrm{adapt}$', fontsize=16)
        ax.grid(visible=True)
        ax.minorticks_off()
        ax.legend(frameon=True, fontsize=12, loc='center left')

        fig.savefig(
            'data/detection_and_adaptivity_{}_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix),
            dpi=300,
            bbox_inches='tight',
        )
        plt_helper.plt.close(fig)

183 

184 

def _unique_event_times(t_switches, tol=1e-10):  # pragma: no cover
    """
    Return the event times with consecutive near-duplicates removed.

    An entry is dropped when it differs from its predecessor by at most ``tol``. The first entry is always kept,
    and an empty input yields an empty array.
    """
    t_switches = np.asarray(t_switches)
    if t_switches.size == 0:
        return t_switches
    mask = np.append([True], np.abs(t_switches[1:] - t_switches[:-1]) > tol)
    return t_switches[mask]


def _state_function_component(h_val_list, n, i):  # pragma: no cover
    """Return component ``i`` of each state-function array as a flat 1D array (the whole array if ``n == 1``)."""
    components = [item[:] if n == 1 else item[:, i] for item in h_val_list]
    return [item.reshape((item.shape[0],)) for item in components]


def _values_at_event(t_list, h_list, t_event_list, atol):  # pragma: no cover
    """For each step size, collect ``h`` at every time point lying within ``atol`` of the event time."""
    return [
        h_item[m]
        for t_item, h_item, t_event in zip(t_list, h_list, t_event_list)
        for m in range(len(t_item))
        if abs(t_item[m] - t_event) <= atol
    ]


def _values_around_event(t_list, h_list, t_event_list, after):  # pragma: no cover
    """
    For each step size, collect ``h`` at the time point directly before (``after=False``) or directly after
    (``after=True``) the event, i.e. at the end points of the interval ``t[m - 1] < t_event < t[m]``.
    """
    offset = 0 if after else 1
    return [
        h_item[m - offset]
        for t_item, h_item, t_event in zip(t_list, h_list, t_event_list)
        for m in range(1, len(t_item))
        if t_item[m - 1] < t_event < t_item[m]
    ]


def plotStateFunctionAroundEvent(u_num, prob_cls_name, M_fix, atol_event=2.7961188919789493e-11):  # pragma: no cover
    r"""
    Plot the state function before, at, and after the event against the step size.

    For every component ``i`` of the state function, one figure with three panels is saved: detection only,
    adaptivity only, and adaptivity together with detection.

    * For runs using detection, the value "at event" is taken from time points of the run itself that lie within
      ``atol_event`` of the event time. The values before and after the event are taken from the run without any
      handling (neither detection nor adaptivity).
    * The adaptivity-only run has no events of its own. For it, the event times found by detection without
      adaptivity serve as reference.

    Near-duplicate event times (closer than ``1e-10``) are removed before the ``i``-th event is selected.

    Note that this routine only makes sense for a state function that remains constant after the event.

    TODO: Function still does not work as expected. Every time the switch estimator is adapted, the tolerance
    ``atol_event`` may no longer fit. If the number of matched time points differs from the number of step sizes,
    matplotlib raises an error when plotting.

    Parameters
    ----------
    u_num : dict
        Contains all the data. Dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the simulation results are stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Name of the problem class.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    atol_event : float, optional
        Absolute tolerance used to decide whether a time point lies at the event.
    """

    title_cases = {
        0: 'Using detection',
        1: 'Using adaptivity',
        2: 'Using adaptivity and detection',
    }

    dt_list = list(u_num.keys())
    use_detection = list(u_num[dt_list[0]][M_fix].keys())
    use_adaptivity = list(u_num[dt_list[0]][M_fix][use_detection[0]].keys())
    h0 = u_num[dt_list[0]][M_fix][use_detection[0]][use_adaptivity[0]]['state_function']
    n = h0[0].shape[0]

    # ---- reference run without any handling; independent of the plotted case, so fetched only once ----
    t_no_handling = [u_num[dt][M_fix][False][False]['t'] for dt in dt_list]
    h_val_no_handling = [u_num[dt][M_fix][False][False]['state_function'] for dt in dt_list]

    for i in range(n):
        fig, ax = plt_helper.plt.subplots(1, 3, figsize=(12, 4), sharex='col', sharey='row', squeeze=False)
        try:
            h_no_handling = _state_function_component(h_val_no_handling, n, i)

            for use_SE in use_detection:
                for use_A in use_adaptivity:
                    if not use_A and not use_SE:
                        continue

                    ind = 0 if (not use_A and use_SE) else (1 if (use_A and not use_SE) else 2)
                    ax[0, ind].set_title(r'{} for $n={}$'.format(title_cases[ind], i + 1))

                    h_val = [u_num[dt][M_fix][use_SE][use_A]['state_function'] for dt in dt_list]
                    h = _state_function_component(h_val, n, i)
                    t = [u_num[dt][M_fix][use_SE][use_A]['t'] for dt in dt_list]

                    # runs without detection have no events themselves: use detection w/o adaptivity as reference
                    use_SE_ref, use_A_ref = (use_SE, use_A) if use_SE else (True, False)
                    t_switches = [
                        _unique_event_times(u_num[dt][M_fix][use_SE_ref][use_A_ref]['t_switches']) for dt in dt_list
                    ]
                    # NOTE(review): assumes every run found at least ``n`` distinct events (IndexError otherwise)
                    t_switch = [t_event[i] for t_event in t_switches]

                    if use_SE:
                        ax[0, ind].plot(
                            dt_list,
                            _values_at_event(t, h, t_switch, atol_event),
                            color='limegreen',
                            marker='s',
                            linestyle='solid',
                            alpha=0.4,
                            label='At event',
                        )
                        # values before/after the event are taken from the run without handling
                        t_around, h_around = t_no_handling, h_no_handling
                    else:
                        t_around, h_around = t, h

                    ax[0, ind].plot(
                        dt_list,
                        _values_around_event(t_around, h_around, t_switch, after=False),
                        color='firebrick',
                        marker='o',
                        linestyle='solid',
                        alpha=0.4,
                        label='Before event',
                    )

                    ax[0, ind].plot(
                        dt_list,
                        _values_around_event(t_around, h_around, t_switch, after=True),
                        color='deepskyblue',
                        marker='*',
                        linestyle='solid',
                        alpha=0.4,
                        label='After event',
                    )

                    ax[0, ind].tick_params(axis='both', which='major', labelsize=16)
                    ax[0, ind].set_xticks(dt_list)
                    ax[0, ind].set_xticklabels(dt_list)
                    ax[0, ind].set_ylim(1e-15, 1e1)
                    ax[0, ind].set_yscale('log', base=10)
                    ax[0, ind].set_xlabel(r'Step size $\Delta t$', fontsize=16)
                    ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)
                    ax[0, ind].grid(visible=True)
                    ax[0, ind].minorticks_off()
                    ax[0, ind].legend(frameon=True, fontsize=12, loc='lower left')

            fig.savefig(
                'data/{}_comparison_event{}_M={}.png'.format(prob_cls_name, i + 1, M_fix), dpi=300, bbox_inches='tight'
            )
        finally:
            # release the figure even if plotting or saving fails
            plt_helper.plt.close(fig)

341 

342 

def plotStateFunctionOverTime(u_num, prob_cls_name, M_fix):  # pragma: no cover
    r"""
    Plot the state function over time for every combination of event detection and adaptivity.

    One figure is saved per step size ``dt``, with one panel per component of the state function. Runs with
    adaptivity are drawn dash-dotted and runs without adaptivity dotted.

    Parameters
    ----------
    u_num : dict
        Contains all the data. Dictionary has the structure ``u_num[dt][M][use_SE][use_A]``,
        where for each step size ``dt``, for each number of collocation nodes ``M``, for each
        combination of event detection ``use_SE`` and adaptivity ``use_A`` the simulation results are stored.
        For more details, see ``pySDC.projects.PinTSimE.battery_model.getDataDict``.
    prob_cls_name : str
        Indicates the name of the problem class to be considered.
    M_fix : int
        Fixed number of collocation nodes the plot is generated for.
    """

    dt_list = list(u_num.keys())
    use_detection = list(u_num[dt_list[0]][M_fix].keys())
    use_adaptivity = list(u_num[dt_list[0]][M_fix][use_detection[0]].keys())
    h0 = u_num[dt_list[0]][M_fix][use_detection[0]][use_adaptivity[0]]['state_function']
    # number of state-function components, taken from the first stored entry
    n = h0[0].shape[0]
    figsize = (7.5, 5) if n == 1 else (12, 5)

    for dt in dt_list:
        fig, ax = plt_helper.plt.subplots(1, n, figsize=figsize, sharex='col', sharey='row', squeeze=False)
        try:
            for use_SE in use_detection:
                for use_A in use_adaptivity:
                    t = u_num[dt][M_fix][use_SE][use_A]['t']
                    h_val = u_num[dt][M_fix][use_SE][use_A]['state_function']

                    linestyle = 'dashdot' if use_A else 'dotted'
                    for i in range(n):
                        h = h_val[:] if n == 1 else h_val[:, i]
                        ax[0, i].plot(
                            t, h, linestyle=linestyle, label='Detection: {}, Adaptivity: {}'.format(use_SE, use_A)
                        )

            # configure every panel once, after all curves have been drawn (previously repeated per curve)
            for i in range(n):
                ax[0, i].set_title(r'$n={}$'.format(i + 1))
                ax[0, i].tick_params(axis='both', which='major', labelsize=16)
                ax[0, i].set_ylim(1e-15, 1e0)
                ax[0, i].set_yscale('log', base=10)
                ax[0, i].set_xlabel(r'Time $t$', fontsize=16)
                ax[0, i].grid(visible=True)
                ax[0, i].minorticks_off()
                ax[0, i].legend(frameon=True, fontsize=12, loc='lower left')
            ax[0, 0].set_ylabel(r'Absolute values of h $|h(v_{C_n}(t))|$', fontsize=16)

            fig.savefig(
                'data/{}_state_function_over_time_dt={}_M={}.png'.format(prob_cls_name, dt, M_fix),
                dpi=300,
                bbox_inches='tight',
            )
        finally:
            # release the figure even if plotting or saving fails
            plt_helper.plt.close(fig)

397 

398 

if __name__ == '__main__':
    # only runs when this module is executed as a script, not when it is imported
    run_estimation_check()