Coverage for pySDC/helpers/pySDC_as_gusto_time_discretization.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-20 10:09 +0000

1import firedrake as fd 

2 

3from gusto.time_discretisation.time_discretisation import TimeDiscretisation, wrapper_apply 

4from gusto.core.labels import explicit 

5 

6from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

7from pySDC.implementations.controller_classes.controller_MPI import controller_MPI 

8from pySDC.implementations.problem_classes.GenericGusto import GenericGusto, GenericGustoImex 

9from pySDC.core.hooks import Hooks 

10from pySDC.helpers.stats_helper import get_sorted 

11 

12import logging 

13import numpy as np 

14 

15 

class LogTime(Hooks):
    """
    Utility hook for knowing how far we got when using adaptive step size selection.

    After every step it records the end time of that step (`L.time + L.dt`) under the stats type `'_time'`.
    `pySDC_integrator.apply` reads these entries back to find the time pySDC actually reached.
    """

    def post_step(self, step, level_number):
        """
        Record the time reached at the end of the step.

        Args:
            step: the pySDC step that has just finished
            level_number (int): index of the level whose time is recorded
        """
        # Run the base-class hook first, as pySDC hooks conventionally do, so that any bookkeeping it performs
        # happens before we record our own stats entry.
        super().post_step(step, level_number)

        L = step.levels[level_number]
        self.add_to_stats(
            process=step.status.slot,
            process_sweeper=L.sweep.rank,
            time=L.time,
            level=-1,
            iter=-1,
            sweep=-1,
            type='_time',
            value=L.time + L.dt,
        )

33 

34 

class pySDC_integrator(TimeDiscretisation):
    """
    This class can be entered into Gusto as a time discretization scheme and will solve steps using pySDC.
    It will construct a pySDC controller which can be used by itself and will be used within the time step when called
    from Gusto. Access the controller via `pySDC_integrator.controller`. This class also has `pySDC_integrator.stats`,
    which gathers all of the pySDC stats recorded in the hooks during every time step when used within Gusto.

    This class supports subcycling with multi-step SDC. You can use pseudo-parallelism by simply giving `n_steps` > 1 or
    do proper parallelism by giving a `controller_communicator` of kind `pySDC.FiredrakeEnsembleCommunicator` with the
    appropriate size. You also have to toggle between pseudo and proper parallelism with `useMPIController`.
    """

    def __init__(
        self,
        description,
        controller_params,
        domain,
        field_name=None,
        solver_parameters=None,
        options=None,
        imex=False,
        useMPIController=False,
        n_steps=1,
        controller_communicator=None,
    ):
        """
        Initialization

        Args:
            description (dict): pySDC description
            controller_params (dict): pySDC controller params
            domain (:class:`Domain`): the model's domain object, containing the
                mesh and the compatible function spaces.
            field_name (str, optional): name of the field to be evolved.
                Defaults to None.
            solver_parameters (dict, optional): dictionary of parameters to
                pass to the underlying solver. Defaults to None.
            options (:class:`AdvectionOptions`, optional): an object containing
                options to either be passed to the spatial discretisation, or
                to control the "wrapper" methods, such as Embedded DG or a
                recovery method. Defaults to None.
            imex (bool): Force use of the IMEX problem class. IMEX is also selected automatically in `setup` when
                any term of the equation residual is labelled explicit.
            useMPIController (bool): If True, use the proper parallel `controller_MPI` on `controller_communicator`;
                otherwise use the pseudo-parallel `controller_nonMPI` with `n_steps` steps.
            n_steps (int): Number of pySDC steps per Gusto time step with the pseudo-parallel controller. Ignored
                (with a warning) for the MPI controller, where the communicator size is used instead.
            controller_communicator (pySDC.FiredrakeEnsembleCommunicator, optional): Communicator for the proper
                parallel controller. Required if `useMPIController` is True.

        Raises:
            AssertionError: if `useMPIController` is True and no suitable communicator is given.
        """

        # Backing storage for the `residual` property until the pySDC problem exists. It is set before calling the
        # base-class constructor in case that constructor accesses `self.residual`.
        self._residual = None

        super().__init__(
            domain=domain,
            field_name=field_name,
            solver_parameters=solver_parameters,
            options=options,
        )

        self.description = description
        self.controller_params = controller_params
        # Gusto timestepper; must be set by the user to allow adaptive step size selection (see `apply`)
        self.timestepper = None
        # pySDC step size proposed for the next step; stays None until pySDC changes the step size
        self.dt_next = None
        self.imex = imex
        self.useMPIController = useMPIController
        self.controller_communicator = controller_communicator

        if useMPIController:
            # Check for a missing communicator first: otherwise the type check below would fail on `None` with a
            # less helpful message and this one could never be shown.
            assert (
                controller_communicator is not None
            ), 'You need to supply a communicator when using the MPI controller!'
            assert (
                type(self.controller_communicator).__name__ == 'FiredrakeEnsembleCommunicator'
            ), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(self.controller_communicator)}'
            if n_steps > 1:
                logging.getLogger(type(self).__name__).warning(
                    f'Warning: You selected {n_steps=}, which will be ignored when using the MPI controller!'
                )
            # With the MPI controller, the number of parallel steps is given by the communicator size
            self.n_steps = controller_communicator.size
        else:
            self.n_steps = n_steps

    def setup(self, equation, apply_bcs=True, *active_labels):
        """
        Set up the Gusto time discretisation and construct the pySDC controller for `equation`.

        Args:
            equation: the Gusto equation to be solved
            apply_bcs (bool, optional): passed on to the base-class `setup`
            *active_labels: passed on to the base-class `setup`
        """
        super().setup(equation, apply_bcs, *active_labels)

        # Use the IMEX problem class if any term is labelled explicit or if IMEX was requested explicitly
        imex = any(t.has_label(explicit) for t in equation.residual) or self.imex
        self.description['problem_class'] = GenericGustoImex if imex else GenericGusto

        # User-supplied problem parameters take precedence over the ones derived from Gusto
        self.description['problem_params'] = {
            'equation': equation,
            'solver_parameters': self.solver_parameters,
            'residual': self._residual,
            **self.description.get('problem_params', {}),
        }
        # One Gusto time step is split into `n_steps` pySDC steps
        self.description['level_params']['dt'] = float(self.domain.dt) / self.n_steps

        # Add the utility hook required for step size adaptivity. Build a fresh list instead of appending to the
        # caller's list in place, and do not register the hook twice (e.g. when `controller_params` is reused).
        hook_class = self.controller_params.get('hook_class', [])
        hook_class = list(hook_class) if isinstance(hook_class, (list, tuple)) else [hook_class]
        if LogTime not in hook_class:
            hook_class.append(LogTime)
        self.controller_params['hook_class'] = hook_class

        # prepare controller and variables
        if self.useMPIController:
            self.controller = controller_MPI(
                comm=self.controller_communicator,
                description=self.description,
                controller_params=self.controller_params,
            )
        else:
            self.controller = controller_nonMPI(
                self.n_steps, description=self.description, controller_params=self.controller_params
            )

        self.prob = self.level.prob
        self.sweeper = self.level.sweep
        # pySDC data object that receives the initial condition in `apply`
        self.x0_pySDC = self.prob.dtype_u(self.prob.init)
        self.t = 0
        self.stats = {}

    @property
    def residual(self):
        """
        Make sure the pySDC problem residual and this residual are the same.

        Before `setup` has built the pySDC problem, this returns the locally stored residual; afterwards it returns
        the residual of the problem on the finest level of the first step.
        """
        if hasattr(self, 'prob'):
            return self.prob.residual
        else:
            return self._residual

    @residual.setter
    def residual(self, value):
        """
        Make sure the pySDC problem residual and this residual are the same.

        Before `setup` the value is stored locally. Afterwards it is forwarded to the finest-level problem of every
        step owned by the controller.
        """
        if hasattr(self, 'prob'):
            if self.useMPIController:
                self.controller.S.levels[0].prob.residual = value
            else:
                for S in self.controller.MS:
                    S.levels[0].prob.residual = value
        else:
            self._residual = value

    @property
    def step(self):
        """Get the first step on the controller (the only step owned by this rank for the MPI controller)"""
        if self.useMPIController:
            return self.controller.S
        else:
            return self.controller.MS[0]

    @property
    def level(self):
        """Get the finest pySDC level of the first step"""
        return self.step.levels[0]

    @wrapper_apply
    def apply(self, x_out, x_in):
        """
        Apply the time discretization to advance one whole time step.

        If pySDC changed the step size during the previous call, the new step size is pushed to
        `self.timestepper`, which must then have been set by the user.

        Args:
            x_out (:class:`Function`): the output field to be computed.
            x_in (:class:`Function`): the input field.
        """
        self.x0_pySDC.functionspace.assign(x_in)
        assert np.isclose(
            self.level.params.dt * self.n_steps, float(self.dt)
        ), 'Step sizes have diverged between pySDC and Gusto'

        if self.dt_next is not None:
            assert (
                self.timestepper is not None
            ), 'You need to set self.timestepper to the timestepper in order to facilitate adaptive step size selection here!'
            self.timestepper.dt = fd.Constant(self.dt_next * self.n_steps)
            self.t = self.timestepper.t

        uend, _stats = self.controller.run(u0=self.x0_pySDC, t0=float(self.t), Tend=float(self.t + self.dt))

        # update time variables: remember the step size pySDC wants to use next if it differs from the current one
        if not np.isclose(self.level.params.dt * self.n_steps, float(self.dt)):
            self.dt_next = self.level.params.dt

        # the time actually reached, as recorded by the `LogTime` hook after the last step
        self.t = get_sorted(_stats, type='_time', recomputed=False, comm=self.controller_communicator)[-1][1]

        # update time of the Gusto stepper.
        # After this step, the Gusto stepper updates its time again to arrive at the correct time
        if self.timestepper is not None:
            self.timestepper.t = fd.Constant(self.t - self.dt)

        self.dt = fd.Constant(self.level.params.dt * self.n_steps)

        # update stats in place: rebuilding the dict every step would copy all previously gathered stats each time
        self.stats.update(_stats)
        x_out.assign(uend.functionspace)