Coverage for pySDC/helpers/pySDC_as_gusto_time_discretization.py: 0%
91 statements
coverage.py v7.6.12, created at 2025-02-20 10:09 +0000
import firedrake as fd

from gusto.time_discretisation.time_discretisation import TimeDiscretisation, wrapper_apply
from gusto.core.labels import explicit

from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.controller_classes.controller_MPI import controller_MPI
from pySDC.implementations.problem_classes.GenericGusto import GenericGusto, GenericGustoImex
from pySDC.core.hooks import Hooks
from pySDC.helpers.stats_helper import get_sorted

import logging
import numpy as np


class LogTime(Hooks):
    """
    Utility hook for knowing how far we got when using adaptive step size selection.
    """

    def post_step(self, step, level_number):
        L = step.levels[level_number]
        self.add_to_stats(
            process=step.status.slot,
            process_sweeper=L.sweep.rank,
            time=L.time,
            level=-1,
            iter=-1,
            sweep=-1,
            type='_time',
            value=L.time + L.dt,
        )
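

# Note: the '_time' entries logged above can be read back from the gathered stats
# after a run, e.g. via ``get_sorted(stats, type='_time', recomputed=False)``, as
# done in ``pySDC_integrator.apply`` below.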


class pySDC_integrator(TimeDiscretisation):
    """
    This class can be entered into Gusto as a time discretization scheme and will solve steps using pySDC.
    It constructs a pySDC controller that can also be used on its own and is invoked within each time step
    when called from Gusto. Access the controller via `pySDC_integrator.controller`. This class also has
    `pySDC_integrator.stats`, which gathers all pySDC stats recorded in the hooks during every time step
    when used within Gusto.

    This class supports subcycling with multi-step SDC. Use pseudo-parallelism by simply giving
    `n_steps` > 1, or proper parallelism by giving a `controller_communicator` of kind
    `pySDC.FiredrakeEnsembleCommunicator` with the appropriate size. Toggle between pseudo and proper
    parallelism with `useMPIController`.
    """
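
    # Minimal usage sketch (hedged): `eqns`, `io`, `domain` and the pySDC
    # `description`/`controller_params` dicts are assumptions from a typical Gusto
    # setup, not defined in this module:
    #
    #     scheme = pySDC_integrator(description, controller_params, domain)
    #     stepper = gusto.Timestepper(eqns, scheme, io)
    #     scheme.timestepper = stepper  # required for adaptive step size selection
    #     stepper.run(t=0, tmax=10 * float(domain.dt))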

    def __init__(
        self,
        description,
        controller_params,
        domain,
        field_name=None,
        solver_parameters=None,
        options=None,
        imex=False,
        useMPIController=False,
        n_steps=1,
        controller_communicator=None,
    ):
        """
        Initialization

        Args:
            description (dict): pySDC description
            controller_params (dict): pySDC controller params
            domain (:class:`Domain`): the model's domain object, containing the
                mesh and the compatible function spaces.
            field_name (str, optional): name of the field to be evolved.
                Defaults to None.
            solver_parameters (dict, optional): dictionary of parameters to
                pass to the underlying solver. Defaults to None.
            options (:class:`AdvectionOptions`, optional): an object containing
                options to either be passed to the spatial discretisation, or
                to control the "wrapper" methods, such as Embedded DG or a
                recovery method. Defaults to None.
            imex (bool): whether to use IMEX splitting
            useMPIController (bool): whether to use the pseudo-parallel or proper parallel pySDC controller
            n_steps (int): number of steps done in parallel when using the pseudo-parallel pySDC controller
            controller_communicator (pySDC.FiredrakeEnsembleCommunicator, optional): communicator for the
                proper parallel controller
        """

        self._residual = None

        super().__init__(
            domain=domain,
            field_name=field_name,
            solver_parameters=solver_parameters,
            options=options,
        )

        self.description = description
        self.controller_params = controller_params
        self.timestepper = None
        self.dt_next = None
        self.imex = imex
        self.useMPIController = useMPIController
        self.controller_communicator = controller_communicator

        if useMPIController:
            assert (
                type(self.controller_communicator).__name__ == 'FiredrakeEnsembleCommunicator'
            ), f'Need to give a FiredrakeEnsembleCommunicator here, not {type(self.controller_communicator)}'
            if n_steps > 1:
                logging.getLogger(type(self).__name__).warning(
                    f'You selected {n_steps=}, which will be ignored when using the MPI controller!'
                )
            assert (
                controller_communicator is not None
            ), 'You need to supply a communicator when using the MPI controller!'
            self.n_steps = controller_communicator.size
        else:
            self.n_steps = n_steps

    def setup(self, equation, apply_bcs=True, *active_labels):
        """Set up the time discretisation based on the equation and construct the pySDC controller."""
        super().setup(equation, apply_bcs, *active_labels)

        # Check if any terms are explicit
        imex = any(t.has_label(explicit) for t in equation.residual) or self.imex
        if imex:
            self.description['problem_class'] = GenericGustoImex
        else:
            self.description['problem_class'] = GenericGusto

        self.description['problem_params'] = {
            'equation': equation,
            'solver_parameters': self.solver_parameters,
            'residual': self._residual,
            **self.description['problem_params'],
        }
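        # each pySDC step covers 1/n_steps of the Gusto step, so one controller run
        # over n_steps steps advances the solution by the full Gusto step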
        self.description['level_params']['dt'] = float(self.domain.dt) / self.n_steps

        # add utility hook required for step size adaptivity
        hook_class = self.controller_params.get('hook_class', [])
        if not isinstance(hook_class, list):
            hook_class = [hook_class]
        hook_class.append(LogTime)
        self.controller_params['hook_class'] = hook_class

        # prepare controller and variables
        if self.useMPIController:
            self.controller = controller_MPI(
                comm=self.controller_communicator,
                description=self.description,
                controller_params=self.controller_params,
            )
        else:
            self.controller = controller_nonMPI(
                self.n_steps, description=self.description, controller_params=self.controller_params
            )
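
        # cache handles to the finest level's problem and sweeper and allocate a
        # buffer for the initial conditions in pySDC's datatype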
        self.prob = self.level.prob
        self.sweeper = self.level.sweep
        self.x0_pySDC = self.prob.dtype_u(self.prob.init)
        self.t = 0
        self.stats = {}

    @property
    def residual(self):
        """Make sure the pySDC problem residual and this residual are the same"""
        if hasattr(self, 'prob'):
            return self.prob.residual
        else:
            return self._residual

    @residual.setter
    def residual(self, value):
        """Make sure the pySDC problem residual and this residual are the same"""
        if hasattr(self, 'prob'):
            if self.useMPIController:
                self.controller.S.levels[0].prob.residual = value
            else:
                for S in self.controller.MS:
                    S.levels[0].prob.residual = value
        else:
            self._residual = value

    @property
    def step(self):
        """Get the first step on the controller"""
        if self.useMPIController:
            return self.controller.S
        else:
            return self.controller.MS[0]

    @property
    def level(self):
        """Get the finest pySDC level"""
        return self.step.levels[0]

    @wrapper_apply
    def apply(self, x_out, x_in):
        """
        Apply the time discretization to advance one whole time step.

        Args:
            x_out (:class:`Function`): the output field to be computed.
            x_in (:class:`Function`): the input field.
        """
        self.x0_pySDC.functionspace.assign(x_in)
        assert np.isclose(
            self.level.params.dt * self.n_steps, float(self.dt)
        ), 'Step sizes have diverged between pySDC and Gusto'

        if self.dt_next is not None:
            assert (
                self.timestepper is not None
            ), 'You need to set self.timestepper to the timestepper in order to facilitate adaptive step size selection here!'
            self.timestepper.dt = fd.Constant(self.dt_next * self.n_steps)
            self.t = self.timestepper.t
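
        # run the controller across n_steps pySDC steps, which together advance
        # the solution by one full Gusto time step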
        uend, _stats = self.controller.run(u0=self.x0_pySDC, t0=float(self.t), Tend=float(self.t + self.dt))

        # update time variables
        if not np.isclose(self.level.params.dt * self.n_steps, float(self.dt)):
            self.dt_next = self.level.params.dt

        self.t = get_sorted(_stats, type='_time', recomputed=False, comm=self.controller_communicator)[-1][1]

        # update time of the Gusto stepper.
        # After this step, the Gusto stepper updates its time again to arrive at the correct time
        if self.timestepper is not None:
            self.timestepper.t = fd.Constant(self.t - self.dt)

        self.dt = fd.Constant(self.level.params.dt * self.n_steps)

        # update stats and output
        self.stats = {**self.stats, **_stats}
        x_out.assign(uend.functionspace)
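

# After a run, the gathered statistics can be post-processed with pySDC's helpers,
# e.g. (hedged sketch; `scheme` is a pySDC_integrator instance as in the usage
# sketch above, and 'residual_post_step' is assumed to be logged by default hooks):
#
#     from pySDC.helpers.stats_helper import get_sorted
#     residuals = get_sorted(scheme.stats, type='residual_post_step')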