Coverage for pySDC/implementations/problem_classes/GenericGusto.py: 0%
122 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-20 10:09 +0000
1from pySDC.core.problem import Problem, WorkCounter
2from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh, IMEX_firedrake_mesh
3from gusto.core.labels import (
4 time_derivative,
5 implicit,
6 explicit,
7)
8from firedrake.fml import replace_subject, all_terms, drop
9import firedrake as fd
10import numpy as np
class GenericGusto(Problem):
    """
    Set up solvers based on the equation. Keep in mind that you probably want to use the pySDC-Gusto coupling via
    the `pySDC_integrator` class in the helpers in order to get spatial methods rather than interfacing with this
    class directly.

    Gusto equations work by a residual, which is minimized in nonlinear solvers to obtain the right hand side
    evaluation or the solution to (IMEX) Euler steps. You control what you solve for by manipulating labeled parts
    of the residual.

    Solvers are built lazily and cached in `self.solvers`:
    - under the keys 'mass_matrix' and 'eval_rhs' for the corresponding operations;
    - under the numeric `factor` for each implicit step.
    Once `LHS_cache_size` implicit-step solvers are cached, the oldest one is evicted before a new one is built.
    """

    dtype_u = firedrake_mesh
    dtype_f = firedrake_mesh
    # Number of right hand side evaluation solvers ('eval_rhs' here). Kept for compatibility: the implicit-step
    # cache limit is enforced by counting the numeric (factor) keys directly, so this value is not used for it.
    rhs_n_labels = 1

    def __init__(
        self,
        equation,
        apply_bcs=True,
        solver_parameters=None,
        stop_at_divergence=False,
        LHS_cache_size=12,
        residual=None,
        *active_labels,
    ):
        """
        Initialisation

        Args:
            equation (:class:`PrognosticEquation`): the model's equation.
            apply_bcs (bool, optional): whether to apply the equation's boundary
                conditions. Defaults to True.
            solver_parameters (dict, optional): Solver parameters for the nonlinear variational problems.
                Defaults to GMRES with a block Jacobi / ILU preconditioner.
            stop_at_divergence (bool, optional): Whether to raise an error when the variational problems do not
                converge. Defaults to False.
            LHS_cache_size (int, optional): Number of implicit-step solvers (one per `factor`) kept in the cache.
                Defaults to 12.
            residual (Firedrake.form, optional): Overwrite the residual of the equation, e.g. after adding spatial
                methods. Defaults to None.
            *active_labels (:class:`Label`): labels indicating which terms of
                the equation to include. NOTE(review): not used anywhere in this class.
        """

        self.equation = equation
        self.residual = equation.residual if residual is None else residual
        self.field_name = equation.field_name
        self.fs = equation.function_space
        self.idx = None
        if solver_parameters is None:
            # default solver parameters
            solver_parameters = {'ksp_type': 'gmres', 'pc_type': 'bjacobi', 'sub_pc_type': 'ilu'}
        self.solver_parameters = solver_parameters
        self.stop_at_divergence = stop_at_divergence

        # -------------------------------------------------------------------- #
        # Setup caches
        # -------------------------------------------------------------------- #

        # `x_out` receives the result of every solve; `_u` holds the input of the current evaluation/solve.
        self.x_out = fd.Function(self.fs)
        self.solvers = {}
        self._u = fd.Function(self.fs)

        super().__init__(self.fs)
        self._makeAttributeAndRegister('LHS_cache_size', 'apply_bcs', localVars=locals(), readOnly=True)
        self.work_counters['rhs'] = WorkCounter()
        self.work_counters['ksp'] = WorkCounter()
        self.work_counters['solver_setup'] = WorkCounter()
        self.work_counters['solver'] = WorkCounter()

    @property
    def bcs(self):
        """
        Boundary conditions handed to the variational problems.

        Returns the equation's boundary conditions for `field_name`, or None when `apply_bcs` is False.
        """
        if not self.apply_bcs:
            return None
        else:
            return self.equation.bcs[self.equation.field_name]

    def invert_mass_matrix(self, rhs):
        """
        Solve M(x) = M(rhs) for x.

        M is built from the time-derivative terms of the residual. The problem's boundary conditions are passed
        to the solver.

        NOTE(review): both sides use the same mass form. For a linear, invertible M this presumably reduces to
        `rhs` apart from the boundary-condition rows. Confirm this is the intended operation.

        Args:
            rhs (dtype_u): the values the mass form is applied to on the right hand side.

        Returns:
            dtype_u: the solution x.
        """
        self._u.assign(rhs.functionspace)

        if 'mass_matrix' not in self.solvers:
            mass_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )
            rhs_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self._u, old_idx=self.idx),
                map_if_false=drop,
            )

            problem = fd.NonlinearVariationalProblem((mass_form - rhs_form).form, self.x_out, bcs=self.bcs)
            solver_name = self.field_name + self.__class__.__name__
            self.solvers['mass_matrix'] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            self.work_counters['solver_setup']()

        self.solvers['mass_matrix'].solve()

        return self.dtype_u(self.x_out)

    def eval_f(self, u, *args):
        """
        Evaluate the right hand side f(u).

        Solves M(x) + R(u) = 0 for x. M is built from the time-derivative terms of the residual and R from all
        remaining terms.

        Args:
            u (dtype_u): the state at which to evaluate.
            *args: ignored (e.g. the time).

        Returns:
            dtype_f: the right hand side x.
        """
        self._u.assign(u.functionspace)

        if 'eval_rhs' not in self.solvers:
            residual = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_false=replace_subject(self._u, old_idx=self.idx),
                map_if_true=drop,
            )
            mass_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )

            problem = fd.NonlinearVariationalProblem((mass_form + residual).form, self.x_out, bcs=self.bcs)
            solver_name = self.field_name + self.__class__.__name__
            self.solvers['eval_rhs'] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            self.work_counters['solver_setup']()

        self.solvers['eval_rhs'].solve()
        self.work_counters['rhs']()

        return self.dtype_f(self.x_out)

    def _evict_oldest_lhs_solver_if_full(self):
        """
        Drop the oldest cached implicit-step solver once `LHS_cache_size` of them are cached.

        Implicit-step solvers are keyed by their numeric `factor`. The right hand side and mass matrix solvers
        use non-numeric keys and are never evicted. Dicts preserve insertion order, so the first numeric key
        belongs to the oldest implicit-step solver.
        """
        factor_keys = [key for key in self.solvers if isinstance(key, (int, float, np.integer, np.floating))]
        # Only evict when a factor solver actually exists. Without this guard, small cache sizes would index
        # into an empty list.
        if factor_keys and len(factor_keys) >= self.LHS_cache_size:
            self.solvers.pop(factor_keys[0])

    def solve_system(self, rhs, factor, u0, *args):
        """
        Solve M(x) + factor * R(x) = M(rhs) for x, i.e. (M - factor * f)(x) = M(rhs).

        Args:
            rhs (dtype_u): right hand side of the implicit step.
            factor (float): factor multiplying f in the implicit step.
            u0 (dtype_u): initial guess for the nonlinear solver.
            *args: ignored.

        Returns:
            dtype_u: the solution x.

        Raises:
            firedrake.exceptions.ConvergenceError: if the solver does not converge and `stop_at_divergence` is
                set. Otherwise the error is only logged at debug level.
        """
        self.x_out.assign(u0.functionspace)  # set initial guess
        self._u.assign(rhs.functionspace)

        if factor not in self.solvers:
            self._evict_oldest_lhs_solver_if_full()

            # setup left hand side (M - factor*f)(u)
            # put in output variable
            residual = self.residual.label_map(all_terms, map_if_true=replace_subject(self.x_out, old_idx=self.idx))
            # multiply f by factor
            residual = residual.label_map(
                lambda t: t.has_label(time_derivative), map_if_false=lambda t: fd.Constant(factor) * t
            )

            # subtract right hand side
            mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative), map_if_false=drop)
            residual -= mass_form.label_map(all_terms, map_if_true=replace_subject(self._u, old_idx=self.idx))

            # construct solver
            problem = fd.NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)
            solver_name = f'{self.field_name}-{self.__class__.__name__}-{factor}'
            self.solvers[factor] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            self.work_counters['solver_setup']()

        try:
            self.solvers[factor].solve()
        except fd.exceptions.ConvergenceError as error:
            if self.stop_at_divergence:
                raise
            self.logger.debug(error)

        self.work_counters['ksp'].niter += self.solvers[factor].snes.getLinearSolveIterations()
        self.work_counters['solver']()
        return self.dtype_u(self.x_out)
class GenericGustoImex(GenericGusto):
    """
    IMEX variant of `GenericGusto`.

    The right hand side is split by the Gusto labels `implicit` and `explicit`. Each part is evaluated by its own
    cached solver, keyed by the label. Only the `implicit` terms enter the implicit solve in `solve_system`.
    """

    dtype_f = IMEX_firedrake_mesh
    rhs_n_labels = 2

    def evaluate_labeled_term(self, u, label):
        """
        Evaluate the part of the right hand side carrying `label`.

        Solves M(x) + R_label(u) = 0 for x. M is built from the time-derivative terms of the residual.
        R_label consists of the non-time-derivative terms that carry `label`.

        Args:
            u (dtype_u): the state at which to evaluate.
            label (:class:`Label`): the label selecting the terms, e.g. `implicit` or `explicit`.

        Returns:
            firedrake.Function: the shared work buffer `self.x_out`. It is overwritten by the next solve, so
                copy the result out before triggering another one.
        """
        self._u.assign(u.functionspace)

        if label not in self.solvers:
            residual = self.residual.label_map(
                lambda t: t.has_label(label) and not t.has_label(time_derivative),
                map_if_true=replace_subject(self._u, old_idx=self.idx),
                map_if_false=drop,
            )
            mass_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )

            problem = fd.NonlinearVariationalProblem((mass_form + residual).form, self.x_out, bcs=self.bcs)
            solver_name = self.field_name + self.__class__.__name__
            self.solvers[label] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            # Increment the existing counter. Replacing it with a fresh WorkCounter would discard the tally.
            self.work_counters['solver_setup']()

        self.solvers[label].solve()
        return self.x_out

    def eval_f(self, u, *args):
        """
        Evaluate the right hand side, split into implicit and explicit parts.

        Args:
            u (dtype_u): the state at which to evaluate.
            *args: ignored (e.g. the time).

        Returns:
            dtype_f: `impl` holds the `implicit` terms and `expl` holds the `explicit` terms.
        """
        f = self.dtype_f(self.init)
        # Each result is copied out immediately because evaluate_labeled_term returns the shared buffer x_out.
        f.impl.assign(self.evaluate_labeled_term(u, implicit))
        f.expl.assign(self.evaluate_labeled_term(u, explicit))
        self.work_counters['rhs']()
        return f

    def solve_system(self, rhs, factor, u0, *args):
        """
        Solve M(x) + factor * R_I(x) = M(rhs) for x, i.e. (M - factor * f_I)(x) = M(rhs).

        R_I consists of the terms labelled `implicit`.

        Args:
            rhs (dtype_u): right hand side of the implicit step.
            factor (float): factor multiplying f_I in the implicit step.
            u0 (dtype_u): initial guess for the nonlinear solver.
            *args: ignored.

        Returns:
            dtype_u: the solution x.

        Raises:
            firedrake.exceptions.ConvergenceError: if the solver does not converge and `stop_at_divergence` is
                set. Otherwise the error is only logged at debug level.
        """
        self.x_out.assign(u0.functionspace)  # set initial guess
        self._u.assign(rhs.functionspace)

        if factor not in self.solvers:
            # Evict the oldest implicit-step solver (numeric key) once LHS_cache_size of them are cached.
            # Label-keyed right hand side solvers and the mass matrix solver are never evicted.
            factor_keys = [key for key in self.solvers if isinstance(key, (int, float, np.integer, np.floating))]
            if factor_keys and len(factor_keys) >= self.LHS_cache_size:
                self.solvers.pop(factor_keys[0])

            # setup left hand side (M - factor*f_I)(u)
            # put in output variable
            residual = self.residual.label_map(
                lambda t: t.has_label(time_derivative) or t.has_label(implicit),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )
            # multiply f_I by factor
            residual = residual.label_map(
                lambda t: t.has_label(implicit) and not t.has_label(time_derivative),
                map_if_true=lambda t: fd.Constant(factor) * t,
            )

            # subtract right hand side
            mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative), map_if_false=drop)
            residual -= mass_form.label_map(all_terms, map_if_true=replace_subject(self._u, old_idx=self.idx))

            # construct solver
            problem = fd.NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)
            solver_name = f'{self.field_name}-{self.__class__.__name__}-{factor}'
            self.solvers[factor] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            # Increment the existing counter. Replacing it with a fresh WorkCounter would discard the tally.
            self.work_counters['solver_setup']()

        # Solve exactly once, inside the try, so that stop_at_divergence=False is honoured.
        try:
            self.solvers[factor].solve()
        except fd.exceptions.ConvergenceError as error:
            if self.stop_at_divergence:
                raise
            self.logger.debug(error)

        self.work_counters['ksp'].niter += self.solvers[factor].snes.getLinearSolveIterations()
        self.work_counters['solver']()
        return self.dtype_u(self.x_out)