Coverage for pySDC/projects/Resilience/AC.py: 19%

162 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1# script to run an Allen-Cahn problem 

2from pySDC.implementations.problem_classes.AllenCahn_2D_FFT import allencahn2d_imex 

3from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

4from pySDC.core.hooks import Hooks 

5from pySDC.projects.Resilience.hook import hook_collection, LogData 

6from pySDC.projects.Resilience.strategies import merge_descriptions 

7import matplotlib.pyplot as plt 

8import numpy as np 

9 

10from pySDC.core.errors import ConvergenceError 

11 

12 

class allencahn_imex_timeforcing_adaptivity(allencahn2d_imex):
    """
    FFT-based IMEX Allen-Cahn problem (derived from `allencahn2d_imex`) with an additional driving force.

    The driving force is scaled by a factor that oscillates sinusoidally in time (see `get_time_dep_fac`). This
    changes the time-scale of the dynamics during the run, so that adaptive step size selection can be beneficial.

    Args:
        time_freq (float): Number of oscillation periods of the scaling factor per 0.032 time units
        time_dep_strength (float): Amplitude of the oscillation of the scaling factor
    """

    def __init__(self, time_freq=2.0, time_dep_strength=1e-2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # register the additional parameters as read-only attributes of the problem
        self._makeAttributeAndRegister('time_freq', 'time_dep_strength', localVars=locals(), readOnly=True)

    def eval_f(self, u, t):
        """
        Evaluate the right-hand side including the time-dependent driving force.

        Args:
            u (dtype_u): Current values of the solution
            t (float): Current time

        Returns:
            dtype_f: The right-hand side, split into implicit and explicit parts
        """
        f = super().eval_f(u, t)
        scaling = self.get_time_dep_fac(self.time_freq, self.time_dep_strength, t)

        if self.eps > 0:
            # overwrite the explicit part of the base class with -2/eps^2 * u (1 - u) (1 - 2u), the negative
            # derivative of the double-well potential u^2 (1 - u)^2 scaled by 1/eps^2
            f.expl = -2.0 / self.eps**2 * u * (1.0 - u) * (1.0 - 2.0 * u)

        # total of the right-hand side without the driving force and total of the driving force term
        rhs_total = float(np.sum(f.impl + f.expl))
        force_total = float(np.sum(6.0 * u * (1.0 - u)))

        # strength of the driving force, guarded against division by zero
        dw = 0.0 if force_total == 0.0 else rhs_total / force_total * scaling

        f.expl -= 6.0 * dw * u * (1.0 - u)
        return f

    @staticmethod
    def get_time_dep_fac(time_freq, time_dep_strength, t):
        """
        Scaling factor of the driving force: 1 - time_dep_strength * sin(2 pi * time_freq * t / 0.032).

        Args:
            time_freq (float): Number of oscillation periods per 0.032 time units
            time_dep_strength (float): Amplitude of the oscillation
            t (float): Current time

        Returns:
            float: The scaling factor
        """
        return 1 - time_dep_strength * np.sin(time_freq * 2 * np.pi / 0.032 * t)

48 

49 

class monitor(Hooks):
    """
    Hook that records the radius of the blob computed from the numerical solution together with the exact radius,
    at the start of the run (if it starts at t=0) and after every step.
    """

    # grid points with values above this threshold are counted as belonging to the high phase
    phase_thresh = 0.0

    def __init__(self):
        """
        Initialization of Allen-Cahn monitoring
        """
        super().__init__()
        # initial radius of the blob, taken from the problem in `pre_run`
        self.init_radius = None

    def get_exact_radius(self, t):
        """
        Exact radius at time t following r(t)^2 = r(0)^2 - 2t, clipped at zero.

        Args:
            t (float): Time at which to evaluate the radius

        Returns:
            float: The exact radius
        """
        squared = self.init_radius**2 - 2.0 * t
        return np.sqrt(max(squared, 0))

    @classmethod
    def get_radius(cls, u, dx):
        """
        Radius of a disc with the same area as the high phase, assuming square cells of width dx.

        Args:
            u (dtype_u): The solution
            dx (float): Grid spacing

        Returns:
            float: The computed radius
        """
        num_inside = np.count_nonzero(u > cls.phase_thresh)
        return np.sqrt(num_inside / np.pi) * dx

    @staticmethod
    def get_interface_width(u, L):
        """
        Width of the interface in units of epsilon, measured on the left half of row `nx // 2` of the grid.

        Args:
            u (dtype_u): The solution
            L (pySDC.Level.level): The level, used for grid size, spacing and epsilon

        Returns:
            float: The interface width
        """
        # TODO: How does this generalize to different phase transitions?
        half = L.prob.init[0][0] // 2
        profile = u[half, :half]
        above_low = np.where(profile > -0.99)
        below_high = np.where(profile < 0.99)
        return (below_high[0][-1] - above_low[0][0]) * L.prob.dx / L.prob.eps

    def _add_radii_to_stats(self, step, L, time, computed, exact):
        """Store the computed and the exact radius under the given time in the stats."""
        for kind, val in (('computed_radius', computed), ('exact_radius', exact)):
            self.add_to_stats(
                process=step.status.slot,
                time=time,
                level=-1,
                iter=step.status.iter,
                sweep=L.status.sweep,
                type=kind,
                value=val,
            )

    def pre_run(self, step, level_number):
        """
        Remember the initial radius and, if the run starts at t=0, record the computed and exact radius.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().pre_run(step, level_number)
        L = step.levels[0]

        radius = self.get_radius(L.u[0], L.prob.dx)
        self.init_radius = L.prob.radius

        if L.time == 0.0:
            self._add_radii_to_stats(step, L, L.time, radius, self.init_radius)

    def post_step(self, step, level_number):
        """
        Record the computed radius of the solution at the end of the step and the exact radius at that time.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().post_step(step, level_number)
        L = step.levels[0]
        t_end = L.time + L.dt

        self._add_radii_to_stats(
            step,
            L,
            t_end,
            self.get_radius(L.uend, L.prob.dx),
            self.get_exact_radius(t_end),
        )

146 

147 

def run_AC(
    custom_description=None,
    num_procs=1,
    Tend=1e-2,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    imex=False,
    u0=None,
    t0=None,
    use_MPI=False,
    live_plot=False,
    FFT=True,
    time_forcing=True,
    **kwargs,
):
    """
    Run an Allen-Cahn problem.

    Args:
        custom_description (dict): Overwrite presets. Entries `imex` and `FFT` in its `problem_params` override the
            arguments of the same name and are not forwarded to the problem class. The dictionary passed in is not
            modified.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): A hook (or list of hooks) to store data
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        imex (bool): Solve the finite difference problem IMEX or fully implicit (only used if neither
            `time_forcing` nor `FFT` is set)
        u0 (dtype_u): Initial value
        t0 (float): Starting time
        use_MPI (bool): Whether or not to use MPI
        live_plot (bool): Whether to add the `LivePlot` hook
        FFT (bool): Use the FFT-based discretization instead of finite differences (ignored if `time_forcing`)
        time_forcing (bool): Use the FFT-based problem with additional time-dependent driving force
        **kwargs: May contain `comm`, the communicator for the MPI controller (defaults to `MPI.COMM_WORLD`)

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    if custom_description is not None:
        # `imex` and `FFT` select the problem class and must not be forwarded to it. Remove them from shallow copies
        # so that the caller's dictionaries stay intact and can be reused for further runs.
        custom_description = {**custom_description}
        if 'problem_params' in custom_description:
            custom_problem_params = {**custom_description['problem_params']}
            imex = custom_problem_params.pop('imex', imex)
            FFT = custom_problem_params.pop('FFT', FFT)
            custom_description['problem_params'] = custom_problem_params

    # import problem and sweeper class
    if time_forcing:
        problem_class = allencahn_imex_timeforcing_adaptivity
        from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient as sweeper_class
    elif FFT:
        from pySDC.implementations.problem_classes.AllenCahn_2D_FFT import allencahn2d_imex as problem_class
        from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient as sweeper_class
    elif imex:
        from pySDC.implementations.problem_classes.AllenCahn_2D_FD import allencahn_semiimplicit as problem_class
        from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient as sweeper_class
    else:
        from pySDC.implementations.problem_classes.AllenCahn_2D_FD import allencahn_fullyimplicit as problem_class
        from pySDC.projects.Resilience.sweepers import generic_implicit_efficient as sweeper_class

    level_params = {
        'dt': 1e-4,
        'restol': 1e-8,
    }

    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'num_nodes': 3,
        'QI': 'LU',
        'QE': 'PIC',
    }

    # problem params
    fd_params = {
        'newton_tol': 1e-9,
        'order': 2,
    }
    problem_params = {
        'nvars': (128, 128),
        'init_type': 'circle',
    }
    # The finite difference parameters are only meant for the finite difference problem classes. The time-forcing
    # class derives from the FFT-based `allencahn2d_imex`, so it must not receive them even if `FFT` is False.
    uses_finite_differences = not (time_forcing or FFT)
    if uses_finite_differences:
        problem_params = {**problem_params, **fd_params}

    step_params = {
        'maxiter': 5,
    }

    hooks = hook_class if isinstance(hook_class, list) else [hook_class]
    controller_params = {
        'logger_level': 30,
        'hook_class': hook_collection + hooks + ([LivePlot] if live_plot else []),
        'mssdc_jac': False,
    }

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    description = {
        'problem_class': problem_class,
        'problem_params': problem_params,
        'sweeper_class': sweeper_class,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
    }

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    t0 = 0.0 if t0 is None else t0

    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }
    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0) if u0 is None else u0

    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    crash = False
    try:
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except ConvergenceError as e:
        print(f'Warning: Premature termination!: {e}')
        stats = controller.return_stats()
        crash = True
    return stats, controller, crash

283 

284 

def plot_solution(stats):  # pragma: no cover
    """
    Show the solutions stored under type `u` in the stats one after another, then keep the figure open.

    Args:
        stats (dict): The stats object returned by the controller
    """
    import matplotlib.pyplot as plt
    from pySDC.helpers.stats_helper import get_sorted

    _, ax = plt.subplots(1, 1)

    snapshots = get_sorted(stats, type='u', recomputed=False)
    for time, solution in snapshots:
        ax.imshow(solution, vmin=-1, vmax=1)
        ax.set_title(f't={time:.2e}')
        plt.pause(1e-1)

    plt.show()

297 plt.show() 

298 

299 

class LivePlot(Hooks):  # pragma: no cover
    """
    Hook that redraws a figure after every step. The figure shows the solution, the computed and exact radius of
    the blob (for 2D solutions) and the step size.
    """

    def __init__(self):
        super().__init__()
        self.fig, self.axs = plt.subplots(1, 3, figsize=(12, 4))
        self.radius = []
        self.exact_radius = []
        self.t = []
        self.dt = []

    def post_step(self, step, level_number):
        """
        Append the data of the current step and redraw the figure.

        If the step is going to be restarted, its data is removed again after plotting.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().post_step(step, level_number)
        L = step.levels[level_number]
        self.t += [step.time + step.dt]

        # plot solution
        self.axs[0].cla()
        if len(L.uend.shape) > 1:
            self.axs[0].imshow(L.uend, vmin=0.0, vmax=1.0)

            # plot radius
            self.axs[1].cla()
            # Compute the radius from the solution at the end of the step. This is the time the data is plotted at
            # and the time the exact radius is evaluated at; using `L.u[0]` would lag behind by one step.
            radius = np.sqrt(np.count_nonzero(L.uend > 0.0) / np.pi) * L.prob.dx
            exact_radius = LogRadius.exact_radius(L)

            self.radius += [radius]
            self.exact_radius += [exact_radius]
            self.axs[1].plot(self.t, self.exact_radius, label='exact')
            self.axs[1].plot(self.t, self.radius, label='numerical')
            self.axs[1].set_ylim([0, 0.26])
            self.axs[1].set_xlim([0, 0.03])
            self.axs[1].legend(frameon=False)
            self.axs[1].set_title(r'Radius')
        else:
            self.axs[0].plot(L.prob.xvalues, L.prob.u_exact(t=L.time + L.dt), label='exact')
            self.axs[0].plot(L.prob.xvalues, L.uend, label='numerical')
        self.axs[0].set_title(f't = {step.time + step.dt:.2e}')

        # plot step size
        self.axs[2].cla()
        self.dt += [step.dt]
        self.axs[2].plot(self.t, self.dt)
        self.axs[2].set_yscale('log')
        self.axs[2].axhline(step.levels[level_number].prob.eps ** 2, label=r'$\epsilon^2$', color='black', ls='--')
        self.axs[2].legend(frameon=False)
        self.axs[2].set_xlim([0, 0.03])
        self.axs[2].set_title(r'$\Delta t$')

        # discard the data of a step that is going to be recomputed; lists may be empty (e.g. radius for 1D data)
        if step.status.restart:
            for me in [self.radius, self.exact_radius, self.t, self.dt]:
                try:
                    me.pop(-1)
                except (TypeError, IndexError):
                    pass

        plt.pause(1e-9)

355 

356 

class LogRadius(Hooks):
    """
    Hook that records the computed radius, the exact radius and the interface width at the start (if the run
    starts at t=0) and at the end of the run. At the end it also records the absolute and relative error of the
    computed radius.
    """

    @staticmethod
    def compute_radius(L, u=None):
        """
        Compute the radius of the blob and the width of the interface.

        Args:
            L (pySDC.Level.level): The level
            u (dtype_u): The solution to evaluate. Defaults to `L.u[0]`, the initial value of the step.

        Returns:
            float: Radius of a disc with the same area as the region where the solution is positive
            float: Interface width in units of epsilon, measured on the left half of row `nx // 2` of the grid,
                or NaN if the solution does not cross both thresholds there
        """
        u = L.u[0] if u is None else u

        c = np.count_nonzero(u > 0.0)
        radius = np.sqrt(c / np.pi) * L.prob.dx

        half = int((L.prob.init[0][0]) / 2)
        profile = u[half, :half]
        above_low = np.where(profile > -0.99)[0]
        below_high = np.where(profile < 0.99)[0]
        if len(above_low) == 0 or len(below_high) == 0:
            interface_width = np.nan
        else:
            interface_width = (below_high[-1] - above_low[0]) * L.prob.dx / L.prob.eps

        return radius, interface_width

    @staticmethod
    def exact_radius(L, t=None):
        """
        Exact radius following r(t)^2 = r(0)^2 - 2t, clipped at zero.

        Args:
            L (pySDC.Level.level): The level
            t (float): Time to evaluate at. Defaults to the end of the step, `L.time + L.dt`.

        Returns:
            float: The exact radius
        """
        t = L.time + L.dt if t is None else t
        init_radius = L.prob.radius
        return np.sqrt(max(init_radius**2 - 2.0 * t, 0))

    def _record(self, step, L, time, level, kind, value):
        """Add one entry to the stats, taking process, iteration and sweep from the step and level."""
        self.add_to_stats(
            process=step.status.slot,
            time=time,
            level=level,
            iter=step.status.iter,
            sweep=L.status.sweep,
            type=kind,
            value=value,
        )

    def pre_run(self, step, level_number):
        """
        Overwrite standard pre run hook: record radii and interface width of the initial condition.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().pre_run(step, level_number)
        L = step.levels[0]

        if L.time == 0.0:
            radius, interface_width = self.compute_radius(L)
            # the initial condition belongs to L.time, so compare with the exact radius at that time
            exact_radius = self.exact_radius(L, t=L.time)

            for kind, value in (
                ('computed_radius', radius),
                ('exact_radius', exact_radius),
                ('interface_width', interface_width),
            ):
                self._record(step, L, L.time, -1, kind, value)

    def post_run(self, step, level_number):
        """
        Record radii, interface width and the error of the radius for the final solution.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().post_run(step, level_number)

        L = step.levels[0]
        t_end = L.time + L.dt

        exact_radius = self.exact_radius(L, t=t_end)
        # evaluate the solution at the end of the last step, which is the time all quantities are recorded at
        radius, interface_width = self.compute_radius(L, u=L.uend)

        for kind, value in (
            ('computed_radius', radius),
            ('exact_radius', exact_radius),
            ('interface_width', interface_width),
        ):
            self._record(step, L, t_end, -1, kind, value)

        error = abs(radius - exact_radius)
        self._record(step, L, t_end, level_number, 'e_global_post_run', error)
        self._record(step, L, t_end, level_number, 'e_global_rel_post_run', error / abs(exact_radius))

477 

478 

if __name__ == '__main__':
    # run with the default configuration and animate the resulting solution
    results = run_AC()
    plot_solution(results[0])