Coverage for pySDC/projects/Resilience/Schroedinger.py: 87%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1import numpy as np 

2 

3from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

4from pySDC.implementations.problem_classes.NonlinearSchroedinger_MPIFFT import ( 

5 nonlinearschroedinger_imex, 

6 nonlinearschroedinger_fully_implicit, 

7) 

8from pySDC.projects.Resilience.hook import LogData, hook_collection 

9from pySDC.projects.Resilience.strategies import merge_descriptions 

10from pySDC.projects.Resilience.sweepers import imex_1st_order_efficient, generic_implicit_efficient 

11from pySDC.core.errors import ConvergenceError 

12 

13from pySDC.core.hooks import Hooks 

14 

15import matplotlib.pyplot as plt 

16from mpl_toolkits.axes_grid1 import make_axes_locatable 

17 

18 

class live_plotting_with_error(Hooks):  # pragma: no cover
    """
    Hook that redraws the solution and its error after every step.

    The left panel shows the absolute value of the numerical solution at the end of the step, the right panel the
    absolute difference between that solution and what the problem's ``u_exact`` returns for the same time.
    """

    def __init__(self):
        """
        Create a figure with two panels that share their axes and attach a colorbar axis to each panel.
        """
        super().__init__()
        self.fig, self.axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(12, 7))

        divider = make_axes_locatable(self.axs[1])
        self.cax_right = divider.append_axes('right', size='5%', pad=0.05)
        divider = make_axes_locatable(self.axs[0])
        self.cax_left = divider.append_axes('right', size='5%', pad=0.05)

    def post_step(self, step, level_number):
        """
        Redraw both panels with the state at the end of the step.

        Args:
            step: The current step; only ``step.levels`` is accessed
            level_number (int): Index of the level whose solution is plotted

        Returns:
            None
        """
        lvl = step.levels[level_number]
        # refresh the end point before reading `lvl.uend` (presumably this is what populates it -- confirm in sweeper)
        lvl.sweep.compute_end_point()

        # everything shown below belongs to the end of the step, so the title is labelled with that time as well
        t_end = lvl.time + lvl.dt

        self.axs[0].cla()
        im1 = self.axs[0].imshow(np.abs(lvl.uend), vmin=0, vmax=2.0)
        self.fig.colorbar(im1, cax=self.cax_left)

        self.axs[1].cla()
        im = self.axs[1].imshow(np.abs(lvl.prob.u_exact(t_end) - lvl.uend))
        self.fig.colorbar(im, cax=self.cax_right)

        self.fig.suptitle(f't={t_end:.2f}')
        self.axs[0].set_title('solution')
        self.axs[1].set_title('error')
        # a tiny pause gives the GUI event loop a chance to draw the updated figure
        plt.pause(1e-9)

45 

46 

class live_plotting(Hooks):  # pragma: no cover
    """
    Hook that redraws the absolute value of the solution after every step.
    """

    def __init__(self):
        """
        Create a single-panel figure and attach a colorbar axis to its right-hand side.
        """
        super().__init__()
        self.fig, self.ax = plt.subplots()
        self.cax = make_axes_locatable(self.ax).append_axes('right', size='5%', pad=0.05)

    def post_step(self, step, level_number):
        """
        Replace the image with the solution at the end of the step and label it with the corresponding time.

        Args:
            step: The current step; only ``step.levels`` is accessed
            level_number (int): Index of the level whose solution is plotted

        Returns:
            None
        """
        level = step.levels[level_number]
        level.sweep.compute_end_point()

        magnitude = np.abs(level.uend)
        end_time = level.time + level.dt

        self.ax.cla()
        image = self.ax.imshow(magnitude, vmin=0.2, vmax=1.8)
        self.ax.set_title(f't={end_time:.2f}')
        self.fig.colorbar(image, cax=self.cax)
        # a tiny pause gives the GUI event loop a chance to draw the updated figure
        plt.pause(1e-9)

63 

64 

def run_Schroedinger(
    custom_description=None,
    num_procs=1,
    Tend=1.0,
    hook_class=LogData,
    fault_stuff=None,
    custom_controller_params=None,
    use_MPI=False,
    space_comm=None,
    imex=True,
    **kwargs,
):
    """
    Run a Schroedinger problem with default parameters.

    Args:
        custom_description (dict): Overwrite presets. An ``'imex'`` entry inside its ``'problem_params'`` takes
            precedence over the `imex` argument. The dictionary passed in is not modified.
        num_procs (int): Number of steps for MSSDC
        Tend (float): Time to integrate to
        hook_class (pySDC.Hook or list): A hook, or a list of hooks, to store data
        fault_stuff (dict): A dictionary with information on how to add faults
        custom_controller_params (dict): Overwrite presets
        use_MPI (bool): Whether or not to use MPI
        space_comm (mpi4py.Intracomm): Space communicator
        imex (bool): Whether to use IMEX implementation or the fully implicit one
        **kwargs: Only ``comm`` is read: the communicator handed to the MPI controller when `use_MPI` is set
            (``MPI.COMM_WORLD`` if absent)

    Returns:
        dict: The stats object
        controller: The controller
        bool: If the code crashed
    """
    # 'imex' only selects the problem and sweeper classes here and is removed before the descriptions are merged.
    # Strip it from a copy, so that the caller can reuse the same dictionary for another run.
    if custom_description is not None and 'imex' in custom_description.get('problem_params', {}):
        custom_problem_params = dict(custom_description['problem_params'])
        imex = custom_problem_params.pop('imex')
        custom_description = {**custom_description, 'problem_params': custom_problem_params}

    from mpi4py import MPI

    space_comm = MPI.COMM_SELF if space_comm is None else space_comm
    rank = space_comm.Get_rank()

    # initialize level parameters
    level_params = dict()
    level_params['restol'] = 1e-8
    level_params['dt'] = 1e-01 / 2
    level_params['nsweeps'] = 1

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 3
    sweeper_params['QI'] = 'IE'
    sweeper_params['QE'] = 'PIC'
    sweeper_params['initial_guess'] = 'spread'

    # initialize problem parameters
    problem_params = dict()
    problem_params['nvars'] = (128, 128)
    problem_params['spectral'] = False
    problem_params['c'] = 1.0
    problem_params['comm'] = space_comm
    if not imex:
        # these two are only set for the fully implicit problem class
        problem_params['liniter'] = 99
        problem_params['lintol'] = 1e-8

    # initialize step parameters
    step_params = dict()
    step_params['maxiter'] = 50

    # initialize controller parameters
    extra_hooks = hook_class if isinstance(hook_class, list) else [hook_class]
    controller_params = dict()
    controller_params['logger_level'] = 15 if rank == 0 else 99  # different logger level away from space rank 0
    controller_params['hook_class'] = hook_collection + extra_hooks
    controller_params['mssdc_jac'] = False

    if custom_controller_params is not None:
        controller_params = {**controller_params, **custom_controller_params}

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_params'] = problem_params
    description['problem_class'] = nonlinearschroedinger_imex if imex else nonlinearschroedinger_fully_implicit
    description['sweeper_class'] = imex_1st_order_efficient if imex else generic_implicit_efficient
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    if custom_description is not None:
        description = merge_descriptions(description, custom_description)

    # set time parameters
    t0 = 0.0

    # instantiate controller
    controller_args = {
        'controller_params': controller_params,
        'description': description,
    }
    if use_MPI:
        from pySDC.implementations.controller_classes.controller_MPI import controller_MPI

        comm = kwargs.get('comm', MPI.COMM_WORLD)
        controller = controller_MPI(**controller_args, comm=comm)
        P = controller.S.levels[0].prob
    else:
        controller = controller_nonMPI(**controller_args, num_procs=num_procs)
        P = controller.MS[0].levels[0].prob

    uinit = P.u_exact(t0)

    # insert faults
    if fault_stuff is not None:
        from pySDC.projects.Resilience.fault_injection import prepare_controller_for_faults

        # NOTE(review): this reads the default-initialised `problem_params`; whether a custom 'nvars' is reflected
        # here depends on `merge_descriptions` updating that dictionary in place -- confirm
        nvars = [me / 2 for me in problem_params['nvars']]
        nvars[0] += 1

        rnd_args = {'problem_pos': nvars}
        prepare_controller_for_faults(controller, fault_stuff, rnd_args)

    # call main function to get things done...
    crash = False
    try:
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
    except (ConvergenceError, OverflowError) as e:
        # report the failure to the caller via `crash` and return whatever statistics were collected so far
        print(f'Warning: Premature termination: {e}')
        stats = controller.return_stats()
        crash = True

    return stats, controller, crash

196 

197 

def main():
    """
    Run the fully implicit Schroedinger problem with live plotting and keep the figure open afterwards.

    The space communicator is ``MPI.COMM_WORLD``; the values returned by the run are discarded.
    """
    from mpi4py import MPI

    run_Schroedinger(hook_class=live_plotting, imex=False, space_comm=MPI.COMM_WORLD)
    plt.show()

203 

204 

# only start the live-plotting demo when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()