Coverage for pySDC/projects/Resilience/piline.py: 87%

141 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import numpy as np 

2import matplotlib.pyplot as plt 

3 

4from pySDC.helpers.stats_helper import get_sorted 

5from pySDC.implementations.problem_classes.Piline import piline 

6from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

7from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

8from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity 

9from pySDC.implementations.convergence_controller_classes.hotrod import HotRod 

10from pySDC.projects.Resilience.hook import LogData, hook_collection 

11from pySDC.projects.Resilience.strategies import merge_descriptions 

12 

13 

def run_piline(
    custom_description=None,
    num_procs=1,
    Tend=20.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
):
    """
    Run a Piline problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list/tuple of pySDC.Hook): Hook(s) to store data. They are appended to
            `hook_collection`
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets

    Returns:
        dict: The stats object
        controller: The controller
        Tend: The time that was supposed to be integrated to
    """

    # initialize level parameters
    level_params = {'dt': 5e-2}

    # initialize sweeper parameters
    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'num_nodes': 3,
        'QI': 'IE',
        'QE': 'PIC',
    }

    problem_params = {
        'Vs': 100.0,
        'Rs': 1.0,
        'C1': 1.0,
        'Rpi': 0.2,
        'C2': 1.0,
        'Lpi': 1.0,
        'Rl': 5.0,
    }

    # initialize step parameters
    step_params = {'maxiter': 4}

    # accept a single hook class as well as a list or tuple of them
    if isinstance(hook_class, (list, tuple)):
        extra_hooks = list(hook_class)
    else:
        extra_hooks = [hook_class]

    # initialize controller parameters
    controller_params = {
        'logger_level': 30,
        'hook_class': list(hook_collection) + extra_hooks,
        'mssdc_jac': False,
    }

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = {
        'problem_class': piline,  # pass problem class
        'problem_params': problem_params,  # pass problem parameters
        'sweeper_class': imex_1st_order,  # pass sweeper
        'sweeper_params': sweeper_params,  # pass sweeper parameters
        'level_params': level_params,  # pass level parameters
        'step_params': step_params,
    }

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller
    controller = controller_nonMPI(num_procs=num_procs, controller_params=controller_params, description=description)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        rnd_args = {'iteration': 4}
        args = {'time': 2.5, 'target': 0}
        prepare_controller_for_faults(controller, fault_stuff, rnd_args, args)

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # call main function to get things done; the final solution itself is not needed here
    _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    return stats, controller, Tend

106 

107 

def get_data(stats, recomputed=False):
    """
    Extract useful data from the stats.

    Args:
        stats (pySDC.stats): The stats object of the run
        recomputed (bool): Whether to exclude values that don't contribute to the final solution or not

    Returns:
        dict: Data with the keys
            'v1', 'v2', 'p3': components 0, 1 and 2 of the logged solution `u`
            't': times at which `u` was logged
            'dt', 't_dt': step sizes and the times at which they were logged
            'e_em', 'e_ex': embedded and extrapolated error estimates (entries may be None)
            'restarts', 't_restarts': restart values (> 0 marks a restart) and their times, queried with
                `recomputed=None`
            'sweeps': number of sweeps, queried with `recomputed=None`
            'ready': mask of entries where both error estimates are not None
            'restart_times': times at which the restart value is > 0
    """
    # Query each stats type only once instead of once per derived quantity.
    u = get_sorted(stats, type='u', recomputed=recomputed)
    dt = get_sorted(stats, type='dt', recomputed=recomputed)
    restarts = np.array(get_sorted(stats, type='restart', recomputed=None))

    data = {
        'v1': np.array([me[1][0] for me in u]),
        'v2': np.array([me[1][1] for me in u]),
        'p3': np.array([me[1][2] for me in u]),
        't': np.array([me[0] for me in u]),
        'dt': np.array([me[1] for me in dt]),
        't_dt': np.array([me[0] for me in dt]),
        'e_em': np.array(get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed))[:, 1],
        'e_ex': np.array(get_sorted(stats, type='error_extrapolation_estimate', recomputed=recomputed))[:, 1],
        'restarts': restarts[:, 1],
        't_restarts': restarts[:, 0],
        'sweeps': np.array(get_sorted(stats, type='sweeps', recomputed=None))[:, 1],
    }
    # error estimates are only usable where neither of them is None
    data['ready'] = np.logical_and(data['e_ex'] != np.array(None), data['e_em'] != np.array(None))
    data['restart_times'] = data['t_restarts'][data['restarts'] > 0]
    return data

135 

136 

def plot_error(data, ax, use_adaptivity=True, plot_restarts=False):
    """
    Plot the step size together with the embedded and extrapolated error estimates.

    Args:
        data (dict): Data prepared from stats by `get_data`
        ax: Axis to plot the step size on; the error estimates go on a twin y-axis created from it
        use_adaptivity (bool): Whether adaptivity was used (selects the legend position)
        plot_restarts (bool): Whether to plot vertical lines for restarts

    Returns:
        None
    """
    setup_mpl_from_accuracy_check()
    ax.plot(data['t_dt'], data['dt'], color='black')

    ready = data['ready']
    e_ax = ax.twinx()
    e_ax.plot(data['t'], data['e_em'], label=r'$\epsilon_\mathrm{embedded}$')
    e_ax.plot(data['t'][ready], data['e_ex'][ready], label=r'$\epsilon_\mathrm{extrapolated}$', ls='--')
    e_ax.plot(
        data['t'][ready],
        abs(data['e_em'][ready] - data['e_ex'][ready]),
        label='difference',
        ls='-.',
    )

    if plot_restarts:
        for t_restart in data['restart_times']:
            ax.axvline(t_restart, ls='-.', color='black', alpha=0.5)

    # Empty line on the twin axis so that the step size (plotted on `ax`) also appears in its legend.
    e_ax.plot([None, None], label=r'$\Delta t$', color='black')
    e_ax.set_yscale('log')
    if use_adaptivity:
        e_ax.legend(frameon=False, loc='upper left')
    else:
        e_ax.legend(frameon=False, loc='upper right')

    # Hard-coded axis limits, presumably so that plots of different runs are comparable.
    # NOTE(review): the extracted source lost its indentation; confirm these are not meant to be in the `else`.
    e_ax.set_ylim((7.367539795147197e-12, 1.109667868425781e-05))
    ax.set_ylim((0.012574322653781072, 0.10050387672423527))

    ax.set_xlabel('Time')
    ax.set_ylabel(r'$\Delta t$')

178 

def setup_mpl_from_accuracy_check():
    """
    Configure matplotlib in LaTeX style by delegating to `setup_mpl` from the accuracy check project.
    """
    # imported lazily so that this module can be imported without loading the accuracy check
    import pySDC.projects.Resilience.accuracy_check as accuracy_check

    accuracy_check.setup_mpl()

186 

187 

def plot_solution(data, ax):
    """
    Plot the three solution components `v1`, `v2` and `p3` over time.

    Args:
        data (dict): Data prepared from stats by `get_data`
        ax: Somewhere to plot

    Returns:
        None
    """
    setup_mpl_from_accuracy_check()

    # one curve per component, each with its own line style
    component_styles = (('v1', '-'), ('v2', '--'), ('p3', '-.'))
    for component, line_style in component_styles:
        ax.plot(data['t'], data[component], label=component, ls=line_style)

    ax.legend(frameon=False)
    ax.set_xlabel('Time')

205 

206 

def check_solution(data, use_adaptivity, num_procs, generate_reference=False, rtol=1e-4, atol=1e-8):
    """
    Check the solution against a hard coded reference.

    References exist for serial (`num_procs=1`) and 4-step MSSDC (`num_procs=4`) runs, each with and without
    adaptivity.

    Args:
        data (dict): Data prepared from stats by `get_data`
        use_adaptivity (bool): Whether adaptivity was used
        num_procs (int): Number of steps for MSSDC
        generate_reference (bool): Print a new reference to the console. If no reference exists for this
            configuration, only print it and skip the comparison
        rtol (float): Relative tolerance passed to `np.isclose`
        atol (float): Absolute tolerance passed to `np.isclose`. The default of 1e-8 is larger than the
            reference values of the error estimates (~1e-9), so those entries are effectively only checked to
            be small; pass `atol=0` for a purely relative comparison

    Returns:
        None

    Raises:
        NotImplementedError: If there is no reference for this configuration and `generate_reference` is False
        AssertionError: If a value deviates from the reference
    """
    # (use_adaptivity, num_procs) -> (error message prefix, expected values)
    references = {
        (True, 1): (
            'Error when using adaptivity in serial:',
            {
                'v1': 83.88330442715265,
                'v2': 80.62692930055763,
                'p3': 16.13594155613822,
                'e_em': 4.922608098922865e-09,
                'e_ex': 4.4120077421613226e-08,
                'dt': 0.05,
                'restarts': 1.0,
                'sweeps': 2416.0,
                't': 20.03656747407325,
            },
        ),
        (True, 4): (
            'Error when using adaptivity in parallel:',
            {
                'v1': 83.88320903115796,
                'v2': 80.6269822629629,
                'p3': 16.136084724243805,
                'e_em': 4.0668446388281154e-09,
                'e_ex': 4.901094641240463e-09,
                'dt': 0.05,
                'restarts': 48.0,
                'sweeps': 2592.0,
                't': 20.041499821475185,
            },
        ),
        (False, 4): (
            'Error with fixed step size in parallel:',
            {
                'v1': 83.88400128006428,
                'v2': 80.62656202423844,
                'p3': 16.134849781053525,
                'e_em': 4.277040943634347e-09,
                'e_ex': 4.9707053288253756e-09,
                'dt': 0.05,
                'restarts': 0.0,
                'sweeps': 1600.0,
                't': 20.00000000000015,
            },
        ),
        (False, 1): (
            'Error with fixed step size in serial:',
            {
                'v1': 83.88400149770143,
                'v2': 80.62656173487008,
                'p3': 16.134849851184736,
                'e_em': 4.977994905175365e-09,
                'e_ex': 5.048084913047097e-09,
                'dt': 0.05,
                'restarts': 0.0,
                'sweeps': 1600.0,
                't': 20.00000000000015,
            },
        ),
    }

    got = {
        'v1': data['v1'][-1],
        'v2': data['v2'][-1],
        'p3': data['p3'][-1],
        'e_em': data['e_em'][-1],
        # the extrapolated estimate is None for the first steps, so take the last available one
        'e_ex': data['e_ex'][data['e_ex'] != [None]][-1],
        'dt': data['dt'][-1],
        'restarts': data['restarts'].sum(),
        'sweeps': data['sweeps'].sum(),
        't': data['t'][-1],
    }

    if generate_reference:
        print(f'Adaptivity: {use_adaptivity}, num_procs={num_procs}')
        print('expected = {')
        for k, v in got.items():
            if isinstance(v, (list, np.ndarray)):
                print(f'    \'{k}\': {v[v!=[None]][-1]},')
            else:
                print(f'    \'{k}\': {v},')
        print('}')

    reference = references.get((bool(use_adaptivity), num_procs))
    if reference is None:
        if generate_reference:
            # nothing to compare against yet; the new reference has been printed above
            return
        raise NotImplementedError(
            f'No reference available for use_adaptivity={use_adaptivity}, num_procs={num_procs}'
        )
    error_msg, expected = reference

    for k in expected.keys():
        assert np.isclose(
            expected[k], got[k], rtol=rtol, atol=atol
        ), f'{error_msg} Expected {k}={expected[k]:.4e}, got {k}={got[k]:.4e}'

303 

304 

def residual_adaptivity(plot=False):
    """
    Run Piline in serial with adaptivity based on the residual and check the outcome.

    Args:
        plot (bool): Whether to show the residual and the step size over time

    Returns:
        None

    Raises:
        AssertionError: If the residual exceeds the tolerance or the step size never changes
    """
    from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityResidual

    max_res = 1e-8
    adaptivity_params = {
        'e_tol': max_res,
        'e_tol_low': max_res / 10,
    }
    custom_description = {'convergence_controllers': {AdaptivityResidual: adaptivity_params}}
    stats, _, _ = run_piline(custom_description, num_procs=1)

    residual = get_sorted(stats, type='residual_post_step', recomputed=False)
    dt = get_sorted(stats, type='dt', recomputed=False)

    residual_times = [entry[0] for entry in residual]
    residual_values = [entry[1] for entry in residual]
    dt_times = [entry[0] for entry in dt]
    dt_values = [entry[1] for entry in dt]

    if plot:
        _, res_ax = plt.subplots()
        step_ax = res_ax.twinx()

        res_ax.plot(residual_times, residual_values)
        step_ax.plot(dt_times, dt_values, color='black')
        plt.show()

    max_residual = max(residual_values)
    assert max_residual < max_res, f'Max. allowed residual is {max_res:.2e}, but got {max_residual:.2e}!'

    # adaptivity should actually have varied the step size
    dt_std = np.std(dt_values)
    assert dt_std != 0, f'Expected the step size to change, but standard deviation is {dt_std:.2e}!'

334 

335 

def main():
    """
    Make a variety of tests to see if Hot Rod and Adaptivity work in serial as well as MSSDC.

    For each combination of adaptivity on/off and 1 or 4 processes, Piline is run with Hot Rod. A plot of the
    step size and error estimates is saved to `data/`, and the results are checked against hard coded
    references. For adaptivity with 4 processes, the solution is plotted as well.
    """
    import os

    generate_reference = False

    # `savefig` below fails if the output directory does not exist yet
    os.makedirs('data', exist_ok=True)

    for use_adaptivity in [True, False]:
        for num_procs in [1, 4]:
            # build a fresh description for every run; Adaptivity (if used) comes before HotRod
            convergence_controllers = {}
            if use_adaptivity:
                convergence_controllers[Adaptivity] = {
                    'e_tol': 1e-7,
                    'embedded_error_flavor': 'linearized',
                }
            convergence_controllers[HotRod] = {'HotRod_tol': 1, 'no_storage': num_procs > 1}
            custom_description = {'convergence_controllers': convergence_controllers}

            stats, _, _ = run_piline(custom_description, num_procs=num_procs)
            data = get_data(stats, recomputed=False)

            fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))
            plot_error(data, ax, use_adaptivity)
            name = 'hotrod_adaptive' if use_adaptivity else 'hotrod'
            fig.savefig(f'data/piline_{name}_{num_procs}procs.png', bbox_inches='tight', dpi=300)
            # close before checking so a failed check does not leave the figure open
            plt.close(fig)

            if use_adaptivity and num_procs == 4:
                sol_fig, sol_ax = plt.subplots(1, 1, figsize=(3.5, 3))
                plot_solution(data, sol_ax)
                sol_fig.savefig('data/piline_solution_adaptive.png', bbox_inches='tight', dpi=300)
                plt.close(sol_fig)

            check_solution(data, use_adaptivity, num_procs, generate_reference)

367 

368 

if __name__ == '__main__':
    # the residual-based adaptivity check runs first, then the Hot Rod / adaptivity tests with plots
    residual_adaptivity()
    main()