Coverage for pySDC/projects/Resilience/accuracy_check.py: 80%

142 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import matplotlib as mpl 

2import matplotlib.pylab as plt 

3import numpy as np 

4 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError 

7from pySDC.implementations.convergence_controller_classes.estimate_extrapolation_error import ( 

8 EstimateExtrapolationErrorNonMPI, 

9) 

10from pySDC.core.hooks import Hooks 

11from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep 

12from pySDC.projects.Resilience.strategies import merge_descriptions 

13 

14import pySDC.helpers.plot_helper as plt_helper 

15from pySDC.projects.Resilience.piline import run_piline 

16 

17 

class DoNothing(Hooks):
    """Hook class that adds nothing on top of its `Hooks` base class."""

20 

21 

def setup_mpl(font_size=8):
    """
    Configure matplotlib such that the plots fit in with a TeX script.

    The setup of the plot helper is called with ``reset=True`` first. Afterwards, the label size, the legend
    font size and the axis margins are overridden in the rcParams.

    Args:
        font_size (int): Font size. NOTE(review): accepted, but not used anywhere in this function.

    Returns:
        None
    """
    plt_helper.setup_mpl(reset=True)

    # values applied on top of whatever the plot helper has configured
    overrides = {
        "axes.labelsize": 12,
        "legend.fontsize": 13,
        "axes.xmargin": 0.03,
        "axes.ymargin": 0.03,
    }
    mpl.rcParams.update(overrides)

41 

42 

def _last_logged_value(stats, entry_type):
    """
    Return the most recent value that is not ``None`` among the stats entries of type `entry_type`.

    Args:
        stats (dict): The stats object from a pySDC run
        entry_type (str): The type of the entries to look up

    Returns:
        The last value that is not ``None``, or ``None`` if no usable value was recorded
    """
    entries = get_sorted(stats, type=entry_type)
    if len(entries) == 0:
        # an empty array is one-dimensional, so selecting the value column with `[:, 1]` would raise
        return None

    values = np.array(entries)[:, 1]
    valid = values[values != [None]]
    return valid[-1] if len(valid) > 0 else None


def get_results_from_stats(stats, var, val, hook_class=LogLocalErrorPostStep):
    """
    Extract the results from the stats that are used to compute the order.

    The stats are only evaluated if `hook_class` is `LogLocalErrorPostStep`. For every error type, the last
    recorded value that is not ``None`` is used. Errors for which nothing was recorded keep the default of 0.

    Args:
        stats (dict): The stats object from a pySDC run
        var (str): The variable to compute the order against
        val (float): The value of var corresponding to this run
        hook_class (pySDC.Hook): A hook such that we know what information is available

    Returns:
        dict: The information needed for the order plot
    """
    results = {
        'e_embedded': 0.0,
        'e_extrapolated': 0.0,
        'e': 0.0,
        var: val,
    }

    if hook_class != LogLocalErrorPostStep:
        return results

    e_extrapolated = _last_logged_value(stats, 'error_extrapolation_estimate')
    e_embedded = _last_logged_value(stats, 'error_embedded_estimate')
    e_local = _last_logged_value(stats, 'e_local_post_step')

    if e_extrapolated is not None:
        results['e_extrapolated'] = e_extrapolated

    if e_local is not None:
        # never report less than machine precision for the actual error
        results['e'] = max([e_local, np.finfo(float).eps])

    if e_embedded is not None:
        results['e_embedded'] = e_embedded

    return results

78 

79 

def multiple_runs(
    k=5,
    serial=True,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    hook_class=LogLocalErrorPostStep,
    custom_controller_params=None,
    var='dt',
    avoid_restarts=False,
    embedded_error_flavor=None,
):
    """
    A simple test program to compute the order of accuracy.

    The problem is run once for every value in `dt_list` and the errors of all runs are collected in lists.

    Args:
        k (int): Number of SDC sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to. If left at `None`, every run ends at 30
            times its value from `dt_list`.
        custom_description (dict): Custom description that is merged into the one assembled here
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        hook_class (pySDC.Hook): A hook for recording relevant information
        custom_controller_params (dict): Forwarded to `prob` as `custom_controller_params`
        var (str): The variable to check the order against
        avoid_restarts (bool): Mode of running adaptivity if applicable
        embedded_error_flavor (str): Flavor for the estimation of embedded error

    Returns:
        dict: The errors for different values of var. Empty if `dt_list` contains no values.
    """

    # assemble list of dt
    if dt_list is None:
        if Tend_fixed:
            dt_list = 0.1 * 10.0 ** -(np.arange(3) / 2)
        else:
            dt_list = 0.01 * 10.0 ** -(np.arange(20) / 10.0)

    num_procs = 1 if serial else 5

    if not embedded_error_flavor:
        embedded_error_flavor = 'standard' if avoid_restarts else 'linearized'

    # collects one list per result key; stays empty if there is nothing to run
    res = {}

    # perform the runs
    for value in dt_list:
        desc = {
            'step_params': {'maxiter': k},
            'convergence_controllers': {
                EstimateEmbeddedError.get_implementation(flavor=embedded_error_flavor, useMPI=False): {},
                EstimateExtrapolationErrorNonMPI: {'no_storage': not serial},
            },
        }

        # setup the variable we check the order against
        if var == 'dt':
            desc['level_params'] = {'dt': value}
        elif var == 'e_tol':
            from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

            desc['convergence_controllers'][Adaptivity] = {
                'e_tol': value,
                'avoid_restarts': avoid_restarts,
                'embedded_error_flavor': embedded_error_flavor,
            }

        if custom_description is not None:
            desc = merge_descriptions(desc, custom_description)
        Tend = Tend_fixed if Tend_fixed else 30 * value
        stats, controller, _ = prob(
            custom_description=desc,
            num_procs=num_procs,
            Tend=Tend,
            hook_class=hook_class,
            custom_controller_params=custom_controller_params,
        )

        # errors with respect to the exact solution at the end of the last step
        level = controller.MS[-1].levels[-1]
        e_glob = abs(level.prob.u_exact(t=level.time + level.dt) - level.u[-1])
        e_local = abs(level.prob.u_exact(t=level.time + level.dt, u_init=level.u[0], t_init=level.time) - level.u[-1])

        res_ = get_results_from_stats(stats, var, value, hook_class)
        res_['e_glob'] = e_glob
        res_['e_local'] = e_local

        for key, entry in res_.items():
            res.setdefault(key, []).append(entry)

    return res

177 

178 

def plot_order(res, ax, k):
    """
    Plot the local error against the step size using results from `multiple_runs`.

    The mean of the computed orders is shown in the legend entry.

    Args:
        res (dict): The results from `multiple_runs`
        ax: Somewhere to plot
        k (int): Number of iterations

    Returns:
        None
    """
    palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    color = palette[k - 2]

    error_key = 'e_local'
    mean_order = np.mean(get_accuracy_order(res, key=error_key, thresh=1e-11))

    ax.loglog(res['dt'], res[error_key], color=color, ls='-', label=f'k={k}, p={mean_order:.2f}')
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel(r'$\epsilon$')
    ax.legend(frameon=False, loc='lower right')

200 

201 

def plot(res, ax, k, var='dt', keys=None):
    """
    Plot the order of various errors using the results from `multiple_runs`.

    Errors that are zero in all runs are skipped. For the embedded and the extrapolated error estimates, the
    mean order is compared to the expected one.

    Args:
        res (dict): The dictionary containing the errors
        ax: Somewhere to plot
        k (int): Number of SDC sweeps
        var (str): The variable to compute the order against
        keys (list): List of keys to plot from the results

    Returns:
        None

    Raises:
        AssertionError: If the mean order of an error estimate deviates too much from the expected order
    """
    keys = keys if keys else ['e_embedded', 'e_extrapolated', 'e']
    ls = ['-', ':', '-.']
    color = plt.rcParams['axes.prop_cycle'].by_key()['color'][k - 2]

    for i, key in enumerate(keys):
        if all(me == 0.0 for me in res[key]):
            continue
        order = get_accuracy_order(res, key=key, var=var)

        label = None
        if key == 'e_embedded':
            # triple braces: literal braces around the formatted mean order
            label = rf'$k={{{np.mean(order):.2f}}}$'
            expect_order = k if var == 'dt' else 1.0
            assert np.isclose(
                np.mean(order), expect_order, atol=4e-1
            ), f'Expected embedded error estimate to have order {expect_order} but got {np.mean(order):.2f}'
        elif key == 'e_extrapolated':
            expect_order = k + 1 if var == 'dt' else 1 + 1 / k
            assert np.isclose(
                np.mean(order), expect_order, rtol=3e-1
            ), f'Expected extrapolation error estimate to have order {expect_order} but got {np.mean(order):.2f}'

        # cycle through the line styles such that more than three keys can be plotted
        ax.loglog(res[var], res[key], color=color, ls=ls[i % len(ls)], label=label)

    if var == 'dt':
        ax.set_xlabel(r'$\Delta t$')
    elif var == 'e_tol':
        ax.set_xlabel(r'$\epsilon_\mathrm{TOL}$')
    else:
        ax.set_xlabel(var)
    ax.set_ylabel(r'$\epsilon$')
    ax.legend(frameon=False, loc='lower right')

251 

252 

def get_accuracy_order(results, key='e_embedded', thresh=1e-14, var='dt'):
    """
    Routine to compute the order of accuracy in time

    The runs are ordered by decreasing value of `var` and the order is computed from each pair of neighbouring
    runs. The errors are reordered together with the values of `var`, so the entries of `results` do not need
    to be sorted. Pairs where one of the errors does not exceed `thresh` are left out.

    Args:
        results (dict): the dictionary containing the errors
        key (str): The key in the dictionary corresponding to a specific error
        thresh (float): A threshold below which values are not entered into the order computation
        var (str): The variable to compute the order against

    Returns:
        the list of orders
    """

    # retrieve the list of dt from results
    assert var in results, f'ERROR: expecting the list of {var} in the results dictionary'
    dt_list = results[var]

    # indices of the runs from largest to smallest dt, used for both the dt and the errors
    descending = sorted(range(len(dt_list)), key=lambda idx: dt_list[idx], reverse=True)

    order = []
    # loop over two consecutive errors/dt pairs
    for prev, this in zip(descending, descending[1:]):
        # compute order as log(this_error/prev_error)/log(this_dt/prev_dt)
        try:
            if results[key][this] > thresh and results[key][prev] > thresh:
                order.append(
                    np.log(results[key][this] / results[key][prev]) / np.log(dt_list[this] / dt_list[prev])
                )
        except TypeError:
            # errors that cannot be compared to the threshold, e.g. `None`, are reported and skipped
            print('Type Warning', results[key])

    return order

282 

283 

def plot_orders(
    ax,
    ks,
    serial,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    custom_controller_params=None,
    embedded_error_flavor=None,
):
    """
    Plot only the local error.

    For every number of sweeps, `multiple_runs` is called with the `DoNothing` hook and the result is handed
    to `plot_order`.

    Args:
        ax: Somewhere to plot
        ks (list): List of sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to, passed on to `multiple_runs`
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        custom_controller_params (dict): Custom controller parameters, passed on to `multiple_runs`
        embedded_error_flavor (str): Flavor for the estimation of embedded error

    Returns:
        None
    """
    # everything except the number of sweeps is the same in all runs
    shared_args = {
        'serial': serial,
        'Tend_fixed': Tend_fixed,
        'custom_description': custom_description,
        'prob': prob,
        'dt_list': dt_list,
        'hook_class': DoNothing,
        'custom_controller_params': custom_controller_params,
        'embedded_error_flavor': embedded_error_flavor,
    }

    for k in ks:
        plot_order(multiple_runs(k=k, **shared_args), ax, k)

326 

327 

def plot_all_errors(
    ax,
    ks,
    serial,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    custom_controller_params=None,
    var='dt',
    avoid_restarts=False,
    embedded_error_flavor=None,
    keys=None,
):
    """
    Run the tests for all numbers of sweeps and plot a bunch of error estimates

    After the results of all runs are plotted, black dummy lines are added such that the legend explains the
    line styles.

    Args:
        ax: Somewhere to plot
        ks (list): List of sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to, passed on to `multiple_runs`
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        custom_controller_params (dict): Custom controller parameters, passed on to `multiple_runs`
        var (str): The variable to compute the order against
        avoid_restarts (bool): Mode of running adaptivity if applicable
        embedded_error_flavor (str): Flavor for the estimation of embedded error
        keys (list): List of keys to plot from the results

    Returns:
        None
    """
    for k in ks:
        results = multiple_runs(
            k=k,
            serial=serial,
            Tend_fixed=Tend_fixed,
            custom_description=custom_description,
            prob=prob,
            dt_list=dt_list,
            custom_controller_params=custom_controller_params,
            var=var,
            avoid_restarts=avoid_restarts,
            embedded_error_flavor=embedded_error_flavor,
        )

        # visualize results
        plot(results, ax, k, var=var, keys=keys)

    # lines without data, only there to get entries in the legend
    legend_entries = (
        (r'$\epsilon_\mathrm{embedded}$', '-'),
        (r'$\epsilon_\mathrm{extrapolated}$', ':'),
        (r'$e$', '-.'),
    )
    for label, line_style in legend_entries:
        ax.plot([None, None], color='black', label=label, ls=line_style)
    ax.legend(frameon=False, loc='lower right')

384 

385 

def check_order_with_adaptivity():
    """
    Test the order when running adaptivity.
    Since we replace the step size with the tolerance, we check the order against this.

    Irrespective of the number of sweeps we do, the embedded error estimate should scale linearly with the tolerance,
    since it is supposed to match it as closely as possible.

    The error estimate for the error of the last sweep, however will depend on the number of sweeps we do. The order
    we expect is 1 + 1/k.

    One figure is stored for the serial and one for the parallel run.
    """
    setup_mpl()
    ks = [3, 2]

    # the serial run comes first, then the parallel one
    figure_paths = {
        True: 'data/error_estimate_order_adaptivity.png',
        False: 'data/error_estimate_order_adaptivity_parallel.png',
    }

    for serial, path in figure_paths.items():
        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))
        plot_all_errors(
            ax,
            ks,
            serial,
            Tend_fixed=5e-1,
            var='e_tol',
            dt_list=[1e-5, 5e-6],
            avoid_restarts=False,
            custom_controller_params={'logger_level': 30},
        )
        fig.savefig(path, dpi=300, bbox_inches='tight')
        plt.close(fig)

416 

417 

def check_order_against_step_size():
    """
    Check the order versus the step size for different numbers of sweeps.

    One figure is stored for the serial and one for the parallel run.
    """
    setup_mpl()
    ks = [4, 3, 2]

    # the serial run comes first, then the parallel one
    figure_paths = {
        True: 'data/error_estimate_order.png',
        False: 'data/error_estimate_order_parallel.png',
    }

    for serial, path in figure_paths.items():
        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))

        plot_all_errors(ax, ks, serial, Tend_fixed=1.0)

        fig.savefig(path, dpi=300, bbox_inches='tight')
        plt.close(fig)

434 

435 

def main():
    """Run the order checks: first with adaptivity, then against the step size."""
    for check in (check_order_with_adaptivity, check_order_against_step_size):
        check()


# only run the checks when this file is executed as a script
if __name__ == "__main__":
    main()