Coverage for pySDC/projects/Resilience/Schroedinger.py: 87%
82 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1import numpy as np
3from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
4from pySDC.implementations.problem_classes.NonlinearSchroedinger_MPIFFT import (
5 nonlinearschroedinger_imex,
6 nonlinearschroedinger_fully_implicit,
7)
8from pySDC.projects.Resilience.hook import LogData, hook_collection
9from pySDC.projects.Resilience.strategies import merge_descriptions
10from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient
11from pySDC.core.errors import ConvergenceError
13from pySDC.core.hooks import Hooks
15import matplotlib.pyplot as plt
16from mpl_toolkits.axes_grid1 import make_axes_locatable
class live_plotting_with_error(Hooks):  # pragma: no cover
    """
    Hook that, after every step, shows the magnitude of the solution next to the magnitude of its pointwise error
    with respect to the problem's ``u_exact``.
    """

    def __init__(self):
        super().__init__()
        self.fig, self.axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(12, 7))

        # Fixed colorbar axes to the right of each panel: passing them as ``cax`` on every redraw reuses them
        # instead of carving a new colorbar axis out of the panel each time.
        divider = make_axes_locatable(self.axs[1])
        self.cax_right = divider.append_axes('right', size='5%', pad=0.05)
        divider = make_axes_locatable(self.axs[0])
        self.cax_left = divider.append_axes('right', size='5%', pad=0.05)

    def post_step(self, step, level_number):
        """
        Redraw solution and error at the end of the step.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): The index of the level whose solution is plotted
        """
        lvl = step.levels[level_number]
        lvl.sweep.compute_end_point()

        # ``uend`` lives at the end of the step, so the error reference and the title must both use this time
        # (the title previously showed ``lvl.time``, one step size behind the plotted data).
        t_end = lvl.time + lvl.dt

        self.axs[0].cla()
        im1 = self.axs[0].imshow(np.abs(lvl.uend), vmin=0, vmax=2.0)
        self.fig.colorbar(im1, cax=self.cax_left)

        self.axs[1].cla()
        im = self.axs[1].imshow(np.abs(lvl.prob.u_exact(t_end) - lvl.uend))
        self.fig.colorbar(im, cax=self.cax_right)

        self.fig.suptitle(f't={t_end:.2f}')
        self.axs[0].set_title('solution')
        self.axs[1].set_title('error')
        plt.pause(1e-9)
class live_plotting(Hooks):  # pragma: no cover
    """
    Hook that shows the magnitude of the solution at the end of every step in a single, continuously updated image.
    """

    def __init__(self):
        super().__init__()
        self.fig, self.ax = plt.subplots()
        # dedicated axis to the right of the image that receives the colorbar on every redraw
        self.cax = make_axes_locatable(self.ax).append_axes('right', size='5%', pad=0.05)

    def post_step(self, step, level_number):
        """
        Replace the image with ``|u|`` at the end of the step that just finished.

        Args:
            step (pySDC.Step.step): The current step
            level_number (int): The index of the level whose solution is plotted
        """
        level = step.levels[level_number]
        level.sweep.compute_end_point()

        magnitude = np.abs(level.uend)
        t_end = level.time + level.dt

        self.ax.cla()
        image = self.ax.imshow(magnitude, vmin=0.2, vmax=1.8)
        self.ax.set_title(f't={t_end:.2f}')
        self.fig.colorbar(image, cax=self.cax)
        plt.pause(1e-9)
def run_Schroedinger(
    custom_description=None,
    num_procs=1,
    Tend=1.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    use_MPI=False,
    space_comm=None,
    imex=True,
    **kwargs,
):
    """
    Run a Schroedinger problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets. An ``'imex'`` entry in its ``'problem_params'`` selects the
            implementation (taking precedence over the ``imex`` argument) and is not forwarded to the problem
            class. The dictionary passed in is not modified.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): A hook, or a list/tuple of hooks, to store data
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        use_MPI (bool): Whether or not to use MPI
        space_comm (mpi4py.Intracomm): Space communicator, defaults to ``MPI.COMM_SELF``
        imex (bool): Whether to use IMEX implementation or the fully implicit one
        **kwargs: ``comm`` (mpi4py.Intracomm) is used as time communicator for the MPI controller when
            ``use_MPI`` is set, defaulting to ``MPI.COMM_WORLD``

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    if custom_description is not None:
        custom_problem_params = custom_description.get('problem_params', {})
        if 'imex' in custom_problem_params:
            imex = custom_problem_params['imex']
            # Strip 'imex' from a shallow copy instead of popping it from the caller's dictionary. Popping made a
            # second run with the same custom description silently fall back to the IMEX default.
            custom_description = {
                **custom_description,
                'problem_params': {key: val for key, val in custom_problem_params.items() if key != 'imex'},
            }

    from mpi4py import MPI

    space_comm = MPI.COMM_SELF if space_comm is None else space_comm
    rank = space_comm.Get_rank()

    # initialize level parameters
    level_params = dict()
    level_params['restol'] = 1e-8
    level_params['dt'] = 1e-01 / 2
    level_params['nsweeps'] = 1

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['QE'] = 'PIC'
    sweeper_params['initial_guess'] = 'spread'

    # initialize problem parameters
    problem_params = dict()
    problem_params['nvars'] = (128, 128)
    problem_params['spectral'] = False
    problem_params['c'] = 1.0
    problem_params['comm'] = space_comm
    if not imex:
        # parameters of the linear solver, only needed by the fully implicit problem class
        problem_params['liniter'] = 99
        problem_params['lintol'] = 1e-8

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 50

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 15 if rank == 0 else 99  # only the first space rank logs
    extra_hooks = list(hook_class) if isinstance(hook_class, (list, tuple)) else [hook_class]
    controller_params['hook_class'] = hook_collection + extra_hooks
    controller_params['mssdc_jac'] = False

    # fill description dictionary for easy step instantiation
    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    description = dict()
    description['problem_params'] = problem_params
    description['problem_class'] = nonlinearschroedinger_imex if imex else nonlinearschroedinger_fully_implicit
    description['sweeper_class'] = imex_1st_order_efficient if imex else generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller
    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }
    if use_MPI:
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        # Read nvars from the merged description so that a resolution overridden via custom_description is used
        # for the fault position arguments rather than the default above.
        nvars = [me / 2 for me in description['problem_params']['nvars']]
        nvars[0] += 1

        rnd_args = {'problem_pos': nvars}
        prepare_controller_for_faults(controller, fault_stuff, rnd_args)

    # call main function to get things done...
    crash = False
    try:
        uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except (ConvergenceError, OverflowError) as e:
        # keep whatever statistics were gathered before the run aborted and report the crash to the caller
        print(f'Warning: Premature termination: {e}')
        stats = controller.return_stats()
        crash = True

    return stats, controller, crash
def main():
    """
    Run the fully implicit Schroedinger problem with space parallelism over ``MPI.COMM_WORLD`` and a live plot of
    the solution, then keep the figure open.
    """
    from mpi4py import MPI

    run_options = {'space_comm': MPI.COMM_WORLD, 'hook_class': live_plotting, 'imex': False}
    stats, _controller, _crashed = run_Schroedinger(**run_options)
    plt.show()
if __name__ == "__main__":
    # only run the live-plotting demo when executed as a script, not on import
    main()