Coverage for pySDC/projects/Resilience/accuracy_check.py: 80%

142 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1import matplotlib as mpl 

2import matplotlib.pylab as plt 

3import numpy as np 

4 

5from pySDC.helpers.stats_helper import get_sorted 

6from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError 

7from pySDC.implementations.convergence_controller_classes.estimate_extrapolation_error import ( 

8 EstimateExtrapolationErrorNonMPI, 

9) 

10from pySDC.core.hooks import Hooks 

11from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep 

12from pySDC.projects.Resilience.strategies import merge_descriptions 

13 

14import pySDC.helpers.plot_helper as plt_helper 

15from pySDC.projects.Resilience.piline import run_piline 

16 

17 

class DoNothing(Hooks):
    """
    Hook class that adds nothing on top of the base `Hooks` class.

    Passed as `hook_class` when only the bare run is needed and no extra statistics should be recorded.
    """

20 

21 

def setup_mpl(font_size=8):
    """
    Set up matplotlib to fit in with a TeX script.

    Resets the plotting defaults via `pySDC.helpers.plot_helper.setup_mpl(reset=True)` and then overrides the label
    and legend font sizes as well as the axis margins in `matplotlib.rcParams`.

    Args:
        font_size (int): Currently not used; the label and legend font sizes below are fixed. Kept so that existing
            callers passing it keep working.

    Returns:
        None
    """
    plt_helper.setup_mpl(reset=True)

    # Fixed overrides applied on top of the reset defaults
    style_options = {
        "axes.labelsize": 12,
        "legend.fontsize": 13,
        "axes.xmargin": 0.03,
        "axes.ymargin": 0.03,
    }
    mpl.rcParams.update(style_options)

41 

42 

def _last_recorded_value(stats, quantity):
    """Return the last non-`None` value recorded under `quantity` in `stats`, or `None` if there is no such value."""
    # `get_sorted` yields (time, value) entries; only the values are of interest here
    values = [entry[1] for entry in get_sorted(stats, type=quantity) if entry[1] is not None]
    return values[-1] if values else None


def get_results_from_stats(stats, var, val, hook_class=LogLocalErrorPostStep):
    """
    Extract the results that are used to compute the order from the stats of a single run.

    Args:
        stats (dict): The stats object from a pySDC run
        var (str): The variable to compute the order against
        val (float): The value of var corresponding to this run
        hook_class (pySDC.Hook): A hook such that we know what information is available

    Returns:
        dict: The last recorded embedded error estimate ('e_embedded'), extrapolation error estimate
        ('e_extrapolated') and local error ('e', floored at machine epsilon), plus the entry `var: val`.
        Quantities that were not recorded (or when `hook_class` is not `LogLocalErrorPostStep`) stay at 0.
    """
    results = {
        'e_embedded': 0.0,
        'e_extrapolated': 0.0,
        'e': 0.0,
        var: val,
    }

    if hook_class == LogLocalErrorPostStep:
        # Missing quantities are tolerated instead of failing on indexing an empty array
        e_extrapolated = _last_recorded_value(stats, 'error_extrapolation_estimate')
        if e_extrapolated is not None:
            results['e_extrapolated'] = e_extrapolated

        e_local = _last_recorded_value(stats, 'e_local_post_step')
        if e_local is not None:
            # floor at machine epsilon so that the error can be used in logarithms
            results['e'] = max([e_local, np.finfo(float).eps])

        e_embedded = _last_recorded_value(stats, 'error_embedded_estimate')
        if e_embedded is not None:
            results['e_embedded'] = e_embedded

    return results

78 

79 

def _get_order_description(k, serial, var, val, avoid_restarts, embedded_error_flavor):
    """Build the pySDC description for a single run of `multiple_runs` in which `var` takes the value `val`."""
    desc = {
        'step_params': {'maxiter': k},
        'convergence_controllers': {
            EstimateEmbeddedError.get_implementation(flavor=embedded_error_flavor, useMPI=False): {},
            EstimateExtrapolationErrorNonMPI: {'no_storage': not serial},
        },
    }

    # setup the variable we check the order against
    if var == 'dt':
        desc['level_params'] = {'dt': val}
    elif var == 'e_tol':
        from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity

        desc['convergence_controllers'][Adaptivity] = {
            'e_tol': val,
            'avoid_restarts': avoid_restarts,
            'embedded_error_flavor': embedded_error_flavor,
        }
    return desc


def multiple_runs(
    k=5,
    serial=True,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    hook_class=LogLocalErrorPostStep,
    custom_controller_params=None,
    var='dt',
    avoid_restarts=False,
    embedded_error_flavor=None,
):
    """
    A simple test program to compute the order of accuracy.

    Args:
        k (int): Number of SDC sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to. If left at `None`, the local error will be
                            computed since a fixed number of steps (30) will be performed.
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        hook_class (pySDC.Hook): A hook for recording relevant information
        custom_controller_params (dict): Custom controller parameters passed on to `prob`
        var (str): The variable to check the order against
        avoid_restarts (bool): Mode of running adaptivity if applicable
        embedded_error_flavor (str): Flavor for the estimation of embedded error. Defaults to 'standard' when
                                     `avoid_restarts` is set and to 'linearized' otherwise.

    Returns:
        dict: Maps each result key to a list with one entry per value in `dt_list`. The keys are those returned by
        `get_results_from_stats` plus 'e_glob' and 'e_local', the errors of `level.u[-1]` on the last level of the
        last step in `controller.MS` with respect to the exact solution started from the initial condition and
        from `level.u[0]`, respectively. Empty if `dt_list` is empty.
    """
    # assemble list of dt
    if dt_list is None:
        if Tend_fixed:
            dt_list = 0.1 * 10.0 ** -(np.arange(3) / 2)
        else:
            dt_list = 0.01 * 10.0 ** -(np.arange(20) / 10.0)

    num_procs = 1 if serial else 5

    if not embedded_error_flavor:
        embedded_error_flavor = 'standard' if avoid_restarts else 'linearized'

    res = {}
    for val in dt_list:
        desc = _get_order_description(k, serial, var, val, avoid_restarts, embedded_error_flavor)
        if custom_description is not None:
            desc = merge_descriptions(desc, custom_description)

        Tend = Tend_fixed if Tend_fixed else 30 * val
        stats, controller, _ = prob(
            custom_description=desc,
            num_procs=num_procs,
            Tend=Tend,
            hook_class=hook_class,
            custom_controller_params=custom_controller_params,
        )

        # compare the solution at the end of the last step with the exact solution
        level = controller.MS[-1].levels[-1]
        t_end = level.time + level.dt
        e_glob = abs(level.prob.u_exact(t=t_end) - level.u[-1])
        e_local = abs(level.prob.u_exact(t=t_end, u_init=level.u[0], t_init=level.time) - level.u[-1])

        res_ = get_results_from_stats(stats, var, val, hook_class)
        res_['e_glob'] = e_glob
        res_['e_local'] = e_local

        # collect one entry per run for every key
        for key, value in res_.items():
            res.setdefault(key, []).append(value)
    return res

177 

178 

def plot_order(res, ax, k):
    """
    Plot the local error ('e_local') against the step size ('dt') using results from `multiple_runs`.

    Args:
        res (dict): The results from `multiple_runs`; must contain the keys 'dt' and 'e_local'
        ax: Somewhere to plot
        k (int): Number of iterations

    Returns:
        None
    """
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # offset by 2 so that k=2 gets the first colour; wrap around instead of failing for large k
    color = colors[(k - 2) % len(colors)]

    key = 'e_local'
    order = get_accuracy_order(res, key=key, thresh=1e-11)
    mean_order = np.mean(order) if order else np.nan
    label = f'k={k}, p={mean_order:.2f}'
    ax.loglog(res['dt'], res[key], color=color, ls='-', label=label)
    ax.set_xlabel(r'$\Delta t$')
    ax.set_ylabel(r'$\epsilon$')
    ax.legend(frameon=False, loc='lower right')

200 

201 

def plot(res, ax, k, var='dt', keys=None):
    """
    Plot the order of various errors using the results from `multiple_runs`.

    For the embedded and extrapolated error estimates the measured order is also checked against the expected order
    (k and k + 1 against the step size; 1 and 1 + 1/k against the tolerance) with an assertion.

    Args:
        res (dict): the dictionary containing the errors
        ax: Somewhere to plot
        k (int): Number of SDC sweeps
        var (str): The variable to compute the order against
        keys (list): List of keys to plot from the results. Defaults to ['e_embedded', 'e_extrapolated', 'e'].

    Returns:
        None
    """
    keys = keys if keys else ['e_embedded', 'e_extrapolated', 'e']
    line_styles = ['-', ':', '-.']
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # offset by 2 so that k=2 gets the first colour; wrap around instead of failing for large k
    color = colors[(k - 2) % len(colors)]

    for i, key in enumerate(keys):
        # quantities that were not recorded are left at 0 and cannot be plotted on a log scale
        if all(me == 0.0 for me in res[key]):
            continue
        order = get_accuracy_order(res, key=key, var=var)
        mean_order = np.mean(order) if order else np.nan

        label = None
        if key == 'e_embedded':
            # the legend shows the measured order of the embedded estimate
            label = rf'$k={{{mean_order:.2f}}}$'
            expect_order = k if var == 'dt' else 1.0
            assert np.isclose(
                mean_order, expect_order, atol=4e-1
            ), f'Expected embedded error estimate to have order {expect_order} but got {mean_order:.2f}'
        elif key == 'e_extrapolated':
            expect_order = k + 1 if var == 'dt' else 1 + 1 / k
            assert np.isclose(
                mean_order, expect_order, rtol=3e-1
            ), f'Expected extrapolation error estimate to have order {expect_order} but got {mean_order:.2f}'

        # cycle through the line styles instead of failing for more than three keys
        ax.loglog(res[var], res[key], color=color, ls=line_styles[i % len(line_styles)], label=label)

    if var == 'dt':
        ax.set_xlabel(r'$\Delta t$')
    elif var == 'e_tol':
        ax.set_xlabel(r'$\epsilon_\mathrm{TOL}$')
    else:
        ax.set_xlabel(var)
    ax.set_ylabel(r'$\epsilon$')
    ax.legend(frameon=False, loc='lower right')

253 

254 

def get_accuracy_order(results, key='e_embedded', thresh=1e-14, var='dt'):
    """
    Routine to compute the order of accuracy in time.

    The values of `var` are sorted in descending order together with their errors, so the input does not need to be
    sorted. For each pair of consecutive entries the order is computed as
    log(e_i / e_{i-1}) / log(v_i / v_{i-1}).

    Args:
        results (dict): the dictionary containing the errors
        key (str): The key in the dictionary corresponding to a specific error
        thresh (float): A threshold below which values are not entered into the order computation
        var (str): The variable to compute the order against

    Returns:
        list: the orders computed from consecutive pairs. Pairs where either error is not above `thresh`, or where
        both values of `var` are identical, are skipped.
    """
    # retrieve the list of dt from results
    assert var in results, f'ERROR: expecting the list of {var} in the results dictionary'

    # sort values and errors together so every error stays attached to the value it was computed with
    pairs = sorted(zip(results[var], results[key]), key=lambda pair: pair[0], reverse=True)

    order = []
    for (prev_val, prev_err), (val, err) in zip(pairs[:-1], pairs[1:]):
        if val == prev_val:
            # identical values would lead to a division by log(1) = 0
            continue
        try:
            if err > thresh and prev_err > thresh:
                order.append(np.log(err / prev_err) / np.log(val / prev_val))
        except TypeError:
            print('Type Warning', results[key])

    return order

284 

285 

def plot_orders(
    ax,
    ks,
    serial,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    custom_controller_params=None,
    embedded_error_flavor=None,
):
    """
    Plot only the local error, once for every number of sweeps in `ks`.

    Args:
        ax: Somewhere to plot
        ks (list): List of sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to. If left at `None`, the local error will be
                            computed since a fixed number of steps will be performed.
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        custom_controller_params (dict): Custom controller parameters passed on to `prob`
        embedded_error_flavor (str): Flavor for the estimation of embedded error

    Returns:
        None
    """
    for k in ks:
        res = multiple_runs(
            k=k,
            serial=serial,
            Tend_fixed=Tend_fixed,
            custom_description=custom_description,
            prob=prob,
            dt_list=dt_list,
            hook_class=DoNothing,
            custom_controller_params=custom_controller_params,
            embedded_error_flavor=embedded_error_flavor,
        )
        plot_order(res, ax, k)

328 

329 

def plot_all_errors(
    ax,
    ks,
    serial,
    Tend_fixed=None,
    custom_description=None,
    prob=run_piline,
    dt_list=None,
    custom_controller_params=None,
    var='dt',
    avoid_restarts=False,
    embedded_error_flavor=None,
    keys=None,
):
    """
    Make tests for plotting the error and plot a bunch of error estimates.

    Args:
        ax: Somewhere to plot
        ks (list): List of sweeps
        serial (bool): Whether to do regular SDC or Multi-step SDC with 5 processes
        Tend_fixed (float): The time you want to solve the equation to. If left at `None`, the local error will be
                            computed since a fixed number of steps will be performed.
        custom_description (dict): Custom parameters to pass to the problem
        prob (function): A function that can accept suitable arguments and run a problem (see the Resilience project)
        dt_list (list): A list of values to check the order with
        custom_controller_params (dict): Custom controller parameters passed on to `prob`
        var (str): The variable to compute the order against
        avoid_restarts (bool): Mode of running adaptivity if applicable
        embedded_error_flavor (str): Flavor for the estimation of embedded error
        keys (list): List of keys to plot from the results. Defaults to ['e_embedded', 'e_extrapolated', 'e'].

    Returns:
        None
    """
    keys = keys if keys else ['e_embedded', 'e_extrapolated', 'e']

    for k in ks:
        res = multiple_runs(
            k=k,
            serial=serial,
            Tend_fixed=Tend_fixed,
            custom_description=custom_description,
            prob=prob,
            dt_list=dt_list,
            custom_controller_params=custom_controller_params,
            var=var,
            avoid_restarts=avoid_restarts,
            embedded_error_flavor=embedded_error_flavor,
        )

        # visualize results
        plot(res, ax, k, var=var, keys=keys)

    # Black proxy lines explaining the line styles. `plot` assigns line styles by position in `keys`, so the
    # proxies are derived from `keys` as well instead of assuming the default keys.
    labels = {
        'e_embedded': r'$\epsilon_\mathrm{embedded}$',
        'e_extrapolated': r'$\epsilon_\mathrm{extrapolated}$',
        'e': r'$e$',
    }
    line_styles = ['-', ':', '-.']
    for i, key in enumerate(keys):
        ax.plot([None, None], color='black', label=labels.get(key, key), ls=line_styles[i % len(line_styles)])
    ax.legend(frameon=False, loc='lower right')

386 

387 

def check_order_with_adaptivity():
    """
    Check the order of the error estimates when adaptivity is used.

    With adaptivity the step size is chosen by the controller, so the order is measured against the tolerance
    instead.

    The embedded error estimate should follow the tolerance linearly regardless of the number of sweeps, because
    adaptivity tries to match it as closely as possible.

    The estimate of the error of the last sweep depends on the number of sweeps k; its expected order is 1 + 1/k.
    """
    setup_mpl()
    ks = [3, 2]
    for serial in (True, False):
        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))
        run_options = {
            'Tend_fixed': 5e-1,
            'var': 'e_tol',
            'dt_list': [1e-5, 5e-6],
            'avoid_restarts': False,
            'custom_controller_params': {'logger_level': 30},
        }
        plot_all_errors(ax, ks, serial, **run_options)

        path = (
            'data/error_estimate_order_adaptivity.png'
            if serial
            else 'data/error_estimate_order_adaptivity_parallel.png'
        )
        fig.savefig(path, dpi=300, bbox_inches='tight')
        plt.close(fig)

418 

419 

def check_order_against_step_size():
    """
    Check the order of the error estimates against the step size for several numbers of sweeps.
    """
    setup_mpl()
    ks = [4, 3, 2]
    for serial in (True, False):
        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))

        plot_all_errors(ax, ks, serial, Tend_fixed=1.0)

        path = 'data/error_estimate_order.png' if serial else 'data/error_estimate_order_parallel.png'
        fig.savefig(path, dpi=300, bbox_inches='tight')
        plt.close(fig)

436 

437 

def main():
    """Run the order checks: first with adaptivity, then against the step size."""
    for check in (check_order_with_adaptivity, check_order_against_step_size):
        check()


if __name__ == "__main__":
    # only run the checks when executed as a script, not on import
    main()