Coverage for pySDC/projects/Resilience/AC.py: 16%
111 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1# script to run an Allen-Cahn problem
2from pySDC.implementations.problem_classes.AllenCahn_2D_FD import allencahn_fullyimplicit, allencahn_semiimplicit
3from pySDC.implementations.problem_classes.AllenCahn_2D_FFT import allencahn2d_imex
4from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
5from pySDC.core.Hooks import hooks
6from pySDC.projects.Resilience.hook import hook_collection, LogData
7from pySDC.projects.Resilience.strategies import merge_descriptions
8from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient
9import matplotlib.pyplot as plt
10import numpy as np
12from pySDC.core.Errors import ConvergenceError
def run_AC(
    custom_description=None,
    num_procs=1,
    Tend=1e-2,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    imex=False,
    u0=None,
    t0=None,
    use_MPI=False,
    live_plot=False,
    FFT=True,
    **kwargs,
):
    """
    Run the Allen-Cahn problem with SDC.

    The discretization switches `imex` and `FFT` may also be given as entries of
    `custom_description['problem_params']`. In that case they override the keyword arguments and are stripped from
    the problem parameters before the description is merged. The caller's `custom_description` is left unmodified.

    Args:
        custom_description (dict): Overwrite presets
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list/tuple of pySDC.Hook): Hook(s) to store data
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        imex (bool): Solve the finite difference problem IMEX or fully implicit. Ignored if `FFT` is True
        u0 (dtype_u): Initial value
        t0 (float): Starting time
        use_MPI (bool): Whether or not to use MPI
        live_plot (bool): Whether to attach the `LivePlot` hook
        FFT (bool): Use the FFT based IMEX problem class instead of the finite difference ones
        **kwargs: May contain `comm`, the MPI communicator used when `use_MPI` is True

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    # Extract the discretization switches from the custom description. Work on shallow copies so that removing the
    # switches does not mutate the dictionaries owned by the caller (which may reuse them for further runs).
    if custom_description is not None:
        custom_description = dict(custom_description)
        if 'problem_params' in custom_description:
            custom_problem_params = dict(custom_description['problem_params'])
            imex = custom_problem_params.pop('imex', imex)
            FFT = custom_problem_params.pop('FFT', FFT)
            custom_description['problem_params'] = custom_problem_params

    # import problem and sweeper class; the FFT discretization takes precedence over `imex`
    if FFT:
        from pySDC.implementations.problem_classes.AllenCahn_2D_FFT import allencahn2d_imex as problem_class
        from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient as sweeper_class
    elif imex:
        from pySDC.implementations.problem_classes.AllenCahn_2D_FD import allencahn_semiimplicit as problem_class
        from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient as sweeper_class
    else:
        from pySDC.implementations.problem_classes.AllenCahn_2D_FD import allencahn_fullyimplicit as problem_class
        from pySDC.projects.Resilience.sweepers import generic_implicit_efficient as sweeper_class

    level_params = {
        'dt': 1e-4,
        'restol': 1e-8,
    }

    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'num_nodes': 3,
        'QI': 'LU',
        'QE': 'PIC',
    }

    # problem params
    problem_params = {
        'nvars': (128, 128),
        'init_type': 'circle',
    }
    if not FFT:
        # parameters that only the finite difference implementations take
        fd_params = {
            'newton_tol': 1e-9,
            'order': 2,
        }
        problem_params = {**problem_params, **fd_params}

    step_params = {'maxiter': 5}

    # accept a single hook or a collection of hooks
    extra_hooks = list(hook_class) if isinstance(hook_class, (list, tuple)) else [hook_class]
    controller_params = {
        'logger_level': 30,
        'hook_class': hook_collection + extra_hooks + ([LivePlot] if live_plot else []),
        'mssdc_jac': False,
    }

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    description = {
        'problem_class': problem_class,
        'problem_params': problem_params,
        'sweeper_class': sweeper_class,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
    }

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    t0 = 0.0 if t0 is None else t0

    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }
    if use_MPI:
        from mpi4py import MPI
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0) if u0 is None else u0

    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        prepare_controller_for_faults(controller, fault_stuff)

    crash = False
    try:
        uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except ConvergenceError as e:
        print(f'Warning: Premature termination!: {e}')
        stats = controller.return_stats()
        crash = True
    return stats, controller, crash
def plot_solution(stats):  # pragma: no cover
    """
    Animate the solution snapshots recorded under the type 'u'.

    Every snapshot is drawn as an image with color limits [-1, 1] and titled with its time. After all snapshots have
    been shown, the figure is kept open.

    Args:
        stats (dict): Statistics of a run that contain entries of type 'u'
    """
    import matplotlib.pyplot as plt
    from pySDC.helpers.stats_helper import get_sorted

    _, axis = plt.subplots(1, 1)

    snapshots = get_sorted(stats, type='u', recomputed=False)
    for t_snap, u_snap in snapshots:
        axis.imshow(u_snap, vmin=-1, vmax=1)
        axis.set_title(f't={t_snap:.2e}')
        plt.pause(1e-1)

    plt.show()
class LivePlot(hooks):  # pragma: no cover
    """
    Hook that redraws a figure with three panels after every step.

    - left: the solution (an image for 2D data, numerical vs. exact solution for 1D data)
    - middle: computed vs. exact interface radius (2D data only)
    - right: history of the step size with a reference line at eps**2

    When a step is flagged for restart, the most recent entry of each non-empty history is discarded again.
    """

    def __init__(self):
        super().__init__()
        self.fig, self.axs = plt.subplots(1, 3, figsize=(12, 4))
        self.radius, self.exact_radius, self.t, self.dt = [], [], [], []

    def post_step(self, step, level_number):
        """
        Update the live figure with the data of the step that just finished.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().post_step(step, level_number)
        lvl = step.levels[level_number]
        t_end = step.time + step.dt
        self.t.append(t_end)

        ax_sol, ax_radius, ax_dt = self.axs

        # left panel: the solution
        ax_sol.cla()
        if len(lvl.uend.shape) <= 1:
            ax_sol.plot(lvl.prob.xvalues, lvl.prob.u_exact(t=lvl.time + lvl.dt), label='exact')
            ax_sol.plot(lvl.prob.xvalues, lvl.uend, label='numerical')
            ax_sol.set_title(f't = {t_end:.2e}')
        else:
            ax_sol.imshow(lvl.uend, vmin=0.0, vmax=1.0)

            # middle panel: numerical vs. exact radius of the interface
            ax_radius.cla()
            computed_radius, _ = LogRadius.compute_radius(lvl)
            self.radius.append(computed_radius)
            self.exact_radius.append(LogRadius.exact_radius(lvl))
            ax_radius.plot(self.t, self.exact_radius, label='exact')
            ax_radius.plot(self.t, self.radius, label='numerical')
            ax_radius.set_ylim([0, 0.26])
            ax_radius.set_xlim([0, 0.03])
            ax_radius.legend(frameon=False)
            ax_radius.set_title(r'Radius')

        # right panel: step size history
        ax_dt.cla()
        self.dt.append(step.dt)
        ax_dt.plot(self.t, self.dt)
        ax_dt.set_yscale('log')
        ax_dt.axhline(lvl.prob.eps ** 2, label=r'$\epsilon^2$', color='black', ls='--')
        ax_dt.legend(frameon=False)
        ax_dt.set_xlim([0, 0.03])
        ax_dt.set_title(r'$\Delta t$')

        if step.status.restart:
            # the step will be recomputed, so forget what was recorded for it
            for history in (self.radius, self.exact_radius, self.t, self.dt):
                if history:
                    history.pop()

        plt.pause(1e-9)
class LogRadius(hooks):
    """
    Hook that records the interface radius computed from the solution and from a closed-form expression. It also
    records the interface width and the absolute and relative error of the radius.
    """

    @staticmethod
    def compute_radius(L):
        """
        Estimate the interface radius and the interface width from `L.u[0]`.

        The radius is obtained from the number of grid points where the solution is positive, assuming those points
        form a disk: r = sqrt(count / pi) * dx. The interface width is measured on the first half of the middle
        row. It is the distance between the first point above -0.99 and the last point below 0.99, times dx,
        divided by eps.

        Args:
            L (pySDC.Level.level): The level holding the solution

        Returns:
            float: The estimated radius
            float: The interface width in units of eps

        Raises:
            IndexError: If no point of the half row lies above -0.99 or below 0.99
        """
        u = L.u[0]
        radius = np.sqrt(np.count_nonzero(u > 0.0) / np.pi) * L.prob.dx

        # NOTE(review): init[0][0] is assumed to be the number of grid points in the first direction
        half = int(L.prob.init[0][0]) // 2
        half_row = u[half, :half]
        first_above = np.where(half_row > -0.99)[0][0]
        last_below = np.where(half_row < 0.99)[0][-1]
        interface_width = (last_below - first_above) * L.prob.dx / L.prob.eps

        return radius, interface_width

    @staticmethod
    def exact_radius(L):
        """
        Evaluate sqrt(max(R0**2 - 2 t, 0)) at t = L.time + L.dt, where R0 = L.prob.radius.

        Args:
            L (pySDC.Level.level): The level holding the problem and the time

        Returns:
            float: The reference radius
        """
        init_radius = L.prob.radius
        return np.sqrt(max(init_radius**2 - 2.0 * (L.time + L.dt), 0))

    def pre_run(self, step, level_number):
        """
        Overwrite standard pre run hook. Record radius, exact radius and interface width only for a run that starts
        at t=0.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().pre_run(step, level_number)
        L = step.levels[0]

        if L.time != 0.0:
            return

        radius, interface_width = self.compute_radius(L)
        exact_radius = self.exact_radius(L)

        for quantity, value in (
            ('computed_radius', radius),
            ('exact_radius', exact_radius),
            ('interface_width', interface_width),
        ):
            self.add_to_stats(
                process=step.status.slot,
                time=L.time,
                level=-1,
                iter=step.status.iter,
                sweep=L.status.sweep,
                type=quantity,
                value=value,
            )

    def post_run(self, step, level_number):
        """
        Record the radii, interface width and the absolute and relative error of the radius at the end of the run.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number
        """
        super().post_run(step, level_number)

        L = step.levels[0]

        # NOTE(review): compute_radius reads L.u[0] while the entries are stamped with L.time + L.dt;
        # confirm that L.u[0] holds the final solution here (otherwise L.uend may have been intended).
        exact_radius = self.exact_radius(L)
        radius, interface_width = self.compute_radius(L)
        error = abs(radius - exact_radius)

        for quantity, level, value in (
            ('computed_radius', -1, radius),
            ('exact_radius', -1, exact_radius),
            ('interface_width', -1, interface_width),
            ('e_global_post_run', level_number, error),
            # NOTE(review): exact_radius is 0 once the circle has vanished, making this inf/nan
            ('e_global_rel_post_run', level_number, error / abs(exact_radius)),
        ):
            self.add_to_stats(
                process=step.status.slot,
                time=L.time + L.dt,
                level=level,
                iter=step.status.iter,
                sweep=L.status.sweep,
                type=quantity,
                value=value,
            )
if __name__ == '__main__':
    # Run a short simulation and animate the recorded solution.
    # Note that with the default FFT=True of run_AC, the imex flag has no effect on the chosen discretization.
    from pySDC.implementations.hooks.log_errors import LogLocalErrorPostStep

    stats, _controller, _crashed = run_AC(imex=True, hook_class=LogLocalErrorPostStep)
    plot_solution(stats)