Coverage for pySDC/projects/Resilience/AC.py: 16%

111 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1# script to run an Allen-Cahn problem 

2from pySDC.implementations.problem_classes.AllenCahn_2D_FD import allencahn_fullyimplicit, allencahn_semiimplicit 

3from pySDC.implementations.problem_classes.AllenCahn_2D_FFT import allencahn2d_imex 

4from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

5from pySDC.core.Hooks import hooks 

6from pySDC.projects.Resilience.hook import hook_collection, LogData 

7from pySDC.projects.Resilience.strategies import merge_descriptions 

8from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient 

9import matplotlib.pyplot as plt 

10import numpy as np 

11 

12from pySDC.core.Errors import ConvergenceError 

13 

14 

def run_AC(
    custom_description=None,
    num_procs=1,
    Tend=1e-2,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    imex=False,
    u0=None,
    t0=None,
    use_MPI=False,
    live_plot=False,
    FFT=True,
    **kwargs,
):
    """
    Run the two-dimensional Allen-Cahn problem with a circular initial condition.

    Args:
        custom_description (dict): Overwrite presets. The entries 'imex' and 'FFT' of its 'problem_params' take
                                   precedence over the function arguments of the same name and are not passed on
                                   to the problem class. The dictionary passed in is not modified.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): A hook, or a list of hooks, to store data
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        imex (bool): Solve the finite difference problem IMEX or fully implicit. Has no effect when `FFT` is set,
                     since the FFT discretization is always used with the IMEX sweeper.
        u0 (dtype_u): Initial value
        t0 (float): Starting time
        use_MPI (bool): Whether or not to use MPI
        live_plot (bool): Whether to add the `LivePlot` hook
        FFT (bool): Use the FFT discretization instead of finite differences
        **kwargs: `comm` may be given to select the MPI communicator when `use_MPI` is set

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    if custom_description is not None:
        # 'imex' and 'FFT' select the discretization here and are not parameters of the problem class, so they must
        # be stripped before the description is merged. Work on copies so the caller's dictionaries stay untouched.
        custom_problem_params = dict(custom_description.get('problem_params', {}))
        imex = custom_problem_params.pop('imex', imex)
        FFT = custom_problem_params.pop('FFT', FFT)
        if 'problem_params' in custom_description:
            custom_description = {**custom_description, 'problem_params': custom_problem_params}

    # select problem and sweeper class
    if FFT:
        problem_class = allencahn2d_imex
        sweeper_class = imex_1st_order_efficient
    elif imex:
        problem_class = allencahn_semiimplicit
        sweeper_class = imex_1st_order_efficient
    else:
        problem_class = allencahn_fullyimplicit
        sweeper_class = generic_implicit_efficient

    level_params = {
        'dt': 1e-4,
        'restol': 1e-8,
    }

    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'num_nodes': 3,
        'QI': 'LU',
        'QE': 'PIC',
    }

    # problem params
    problem_params = {
        'nvars': (128, 128),
        'init_type': 'circle',
    }
    if not FFT:
        # parameters that are only passed to the finite difference discretization
        problem_params = {**problem_params, 'newton_tol': 1e-9, 'order': 2}

    step_params = {'maxiter': 5}

    user_hooks = hook_class if isinstance(hook_class, list) else [hook_class]
    controller_params = {
        'logger_level': 30,
        'hook_class': hook_collection + user_hooks + ([LivePlot] if live_plot else []),
        'mssdc_jac': False,
    }
    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    description = {
        'problem_class': problem_class,
        'problem_params': problem_params,
        'sweeper_class': sweeper_class,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
    }
    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    t0 = 0.0 if t0 is None else t0

    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }
    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0) if u0 is None else u0

    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    crash = False
    try:
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except ConvergenceError as e:
        # keep whatever statistics were gathered up to the failure
        print(f'Warning: Premature termination!: {e}')
        stats = controller.return_stats()
        crash = True
    return stats, controller, crash

146 

147 

def plot_solution(stats):  # pragma: no cover
    """
    Animate the solution stored in `stats` under the type 'u' (entries that were not recomputed).

    The image is created once and only its data is replaced for subsequent frames. Calling `imshow` for every
    frame would stack ever more images on the axes and make each redraw slower as the animation progresses.

    Args:
        stats (dict): The stats object returned by the controller
    """
    import matplotlib.pyplot as plt
    from pySDC.helpers.stats_helper import get_sorted

    _, ax = plt.subplots(1, 1)

    image = None
    for t, u in get_sorted(stats, type='u', recomputed=False):
        if image is None:
            image = ax.imshow(u, vmin=-1, vmax=1)
        else:
            image.set_data(u)
        ax.set_title(f't={t:.2e}')
        plt.pause(1e-1)

    plt.show()

161 

162 

class LivePlot(hooks):  # pragma: no cover
    """
    Hook that redraws a figure with three panels after every step: the solution, the radius of the circle (only for
    solutions with more than one dimension) and the history of the step size.
    """

    def __init__(self):
        super().__init__()
        self.fig, self.axs = plt.subplots(1, 3, figsize=(12, 4))

        # histories of the plotted quantities, one entry per step
        self.radius = []
        self.exact_radius = []
        self.t = []
        self.dt = []

    def post_step(self, step, level_number):
        """
        Append the data of the step that just finished and redraw all panels.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().post_step(step, level_number)
        lvl = step.levels[level_number]
        ax_solution, ax_radius, ax_dt = self.axs
        t_end = step.time + step.dt

        self.t.append(t_end)

        # left panel: the solution
        ax_solution.cla()
        if len(lvl.uend.shape) <= 1:
            ax_solution.plot(lvl.prob.xvalues, lvl.prob.u_exact(t=lvl.time + lvl.dt), label='exact')
            ax_solution.plot(lvl.prob.xvalues, lvl.uend, label='numerical')
        else:
            ax_solution.imshow(lvl.uend, vmin=0.0, vmax=1.0)

            # middle panel: computed and exact radius over time
            ax_radius.cla()
            computed, _ = LogRadius.compute_radius(lvl)
            self.radius.append(computed)
            self.exact_radius.append(LogRadius.exact_radius(lvl))

            ax_radius.plot(self.t, self.exact_radius, label='exact')
            ax_radius.plot(self.t, self.radius, label='numerical')
            ax_radius.set_ylim([0, 0.26])
            ax_radius.set_xlim([0, 0.03])
            ax_radius.legend(frameon=False)
            ax_radius.set_title(r'Radius')
        ax_solution.set_title(f't = {t_end:.2e}')

        # right panel: step size history with epsilon^2 as reference line
        ax_dt.cla()
        self.dt.append(step.dt)
        ax_dt.plot(self.t, self.dt)
        ax_dt.set_yscale('log')
        ax_dt.axhline(lvl.prob.eps**2, label=r'$\epsilon^2$', color='black', ls='--')
        ax_dt.legend(frameon=False)
        ax_dt.set_xlim([0, 0.03])
        ax_dt.set_title(r'$\Delta t$')

        if step.status.restart:
            # the step will be recomputed, so drop the last entry of every non-empty history
            for history in (self.radius, self.exact_radius, self.t, self.dt):
                if history:
                    history.pop()

        plt.pause(1e-9)

218 

219 

class LogRadius(hooks):
    """
    Hook that records the computed radius of the circle, a reference radius and the width of the interface at the
    beginning and at the end of the run, plus the absolute and relative error of the radius at the end.
    """

    @staticmethod
    def compute_radius(L, u=None):
        """
        Estimate the radius of the circle and the width of the diffuse interface from a solution.

        The radius is that of a disc whose area equals the number of grid points with u > 0 times dx^2. The
        interface width is measured along the middle row of the grid, from the left boundary up to the center. It
        is the distance between the first point with u > -0.99 and the last point with u < 0.99, in units of
        epsilon.

        Args:
            L (pySDC.Level.level): The level providing the problem (and the solution if `u` is not given)
            u (dtype_u): Solution to analyze; defaults to `L.u[0]`

        Returns:
            float: Computed radius
            float: Interface width in units of epsilon, or NaN if no such points exist along the middle row
        """
        if u is None:
            u = L.u[0]
        dx = L.prob.dx

        radius = np.sqrt(np.count_nonzero(u > 0.0) / np.pi) * dx

        # NOTE(review): assumes L.prob.init[0] is the grid shape, i.e. init[0][0] is the number of points along axis 0
        center = int(L.prob.init[0][0] / 2)
        half_row = u[center, :center]
        above_lower = np.flatnonzero(half_row > -0.99)
        below_upper = np.flatnonzero(half_row < 0.99)
        if above_lower.size == 0 or below_upper.size == 0:
            # e.g. once the circle has vanished there is no interface to measure
            interface_width = np.nan
        else:
            interface_width = (below_upper[-1] - above_lower[0]) * dx / L.prob.eps

        return radius, interface_width

    @staticmethod
    def exact_radius(L, t=None):
        """
        Compute the reference radius sqrt(max(R0^2 - 2 t, 0)), with R0 the radius of the problem's initial condition.

        Args:
            L (pySDC.Level.level): The current level
            t (float): Time at which to evaluate; defaults to the end of the current step, `L.time + L.dt`

        Returns:
            float: Reference radius (zero once R0^2 - 2 t becomes negative)
        """
        if t is None:
            t = L.time + L.dt
        init_radius = L.prob.radius
        return np.sqrt(max(init_radius**2 - 2.0 * t, 0))

    def _add_radius_stats(self, step, L, time, radius, exact_radius, interface_width):
        """Store computed radius, reference radius and interface width at `time` with level -1."""
        for name, value in (
            ('computed_radius', radius),
            ('exact_radius', exact_radius),
            ('interface_width', interface_width),
        ):
            self.add_to_stats(
                process=step.status.slot,
                time=time,
                level=-1,
                iter=step.status.iter,
                sweep=L.status.sweep,
                type=name,
                value=value,
            )

    def pre_run(self, step, level_number):
        """
        Overwrite standard pre run hook: record the initial radius data when starting at t=0.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().pre_run(step, level_number)
        L = step.levels[0]

        if L.time == 0.0:
            radius, interface_width = self.compute_radius(L)
            # evaluate the reference at the time the data is stored with, not at the end of the first step
            self._add_radius_stats(step, L, L.time, radius, self.exact_radius(L, t=L.time), interface_width)

    def post_run(self, step, level_number):
        """
        Record radius data and the error of the radius at the end of the last step.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().post_run(step, level_number)

        L = step.levels[0]
        t_end = L.time + L.dt

        # the data is stored at the end of the step, so analyze the end point rather than the initial value
        # u[0] of the step; fall back to u[0] if no end point is available
        u = L.uend if L.uend is not None else L.u[0]
        radius, interface_width = self.compute_radius(L, u=u)
        exact_radius = self.exact_radius(L)

        self._add_radius_stats(step, L, t_end, radius, exact_radius, interface_width)

        error = abs(radius - exact_radius)
        for name, value in (
            ('e_global_post_run', error),
            ('e_global_rel_post_run', error / abs(exact_radius)),
        ):
            self.add_to_stats(
                process=step.status.slot,
                time=t_end,
                level=level_number,
                iter=step.status.iter,
                sweep=L.status.sweep,
                type=name,
                value=value,
            )

340 

341 

if __name__ == '__main__':
    from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep

    # run with the default setup (the default FFT discretization is always solved with the IMEX sweeper, so the
    # `imex` flag has no effect here) and animate the stored solution
    stats = run_AC(imex=True, hook_class=LogLocalErrorPostStep)[0]
    plot_solution(stats)