Coverage for pySDC/implementations/problem_classes/GenericGusto.py: 0%

122 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-20 10:09 +0000

1from pySDC.core.problem import Problem, WorkCounter 

2from pySDC.implementations.datatype_classes.firedrake_mesh import firedrake_mesh, IMEX_firedrake_mesh 

3from gusto.core.labels import ( 

4 time_derivative, 

5 implicit, 

6 explicit, 

7) 

8from firedrake.fml import replace_subject, all_terms, drop 

9import firedrake as fd 

10import numpy as np 

11 

12 

class GenericGusto(Problem):
    """
    Set up solvers based on the equation. Keep in mind that you probably want to use the pySDC-Gusto coupling via
    the `pySDC_integrator` class in the helpers in order to get spatial methods rather than interfacing with this
    class directly.

    Gusto equations work by a residual, which is minimized in nonlinear solvers to obtain the right hand side
    evaluation or the solution to (IMEX) Euler steps. You control what you solve for by manipulating labeled parts
    of the residual.

    Notation used in the method docstrings: ``M(.)`` denotes the terms of the residual labelled ``time_derivative``
    and ``N(.)`` all remaining terms. Nonlinear variational solvers are built lazily and cached in ``self.solvers``:
    right hand side / mass solvers under string keys (``'eval_rhs'``, ``'mass_matrix'``) and implicit-step solvers
    under their numeric step-size ``factor``.
    """

    dtype_u = firedrake_mesh
    dtype_f = firedrake_mesh
    # Number of right hand side solvers cached next to the implicit-step solvers. Kept for backward compatibility;
    # the cache limit in `solve_system` counts the numeric (implicit-step) keys directly.
    rhs_n_labels = 1

    def __init__(
        self,
        equation,
        apply_bcs=True,
        solver_parameters=None,
        stop_at_divergence=False,
        LHS_cache_size=12,
        residual=None,
        *active_labels,
    ):
        """
        Initialisation

        Args:
            equation (:class:`PrognosticEquation`): the model's equation.
            apply_bcs (bool, optional): whether to apply the equation's boundary
                conditions. Defaults to True.
            solver_parameters (dict, optional): Solver parameters for the nonlinear variational problems. Defaults
                to ``{'ksp_type': 'gmres', 'pc_type': 'bjacobi', 'sub_pc_type': 'ilu'}``.
            stop_at_divergence (bool, optional): Whether to re-raise the `ConvergenceError` of the implicit solves
                in `solve_system` instead of only logging it at debug level. Defaults to False.
            LHS_cache_size (int, optional): Maximum number of implicit-step solvers (one per distinct `factor`)
                kept in the cache. Defaults to 12.
            residual (Firedrake.form, optional): Overwrite the residual of the equation, e.g. after adding spatial
                methods. Defaults to None, i.e. use ``equation.residual``.
            *active_labels (:class:`Label`): labels indicating which terms of
                the equation to include. NOTE(review): not referenced anywhere in this class.
        """

        self.equation = equation
        self.residual = equation.residual if residual is None else residual
        self.field_name = equation.field_name
        self.fs = equation.function_space
        # forwarded as `old_idx` to every `replace_subject` call below
        self.idx = None
        if solver_parameters is None:
            # default solver parameters
            solver_parameters = {'ksp_type': 'gmres', 'pc_type': 'bjacobi', 'sub_pc_type': 'ilu'}
        self.solver_parameters = solver_parameters
        self.stop_at_divergence = stop_at_divergence

        # -------------------------------------------------------------------- #
        # Setup caches
        # -------------------------------------------------------------------- #

        # `x_out` is the unknown of every variational problem and `_u` holds the known input. Both Functions are
        # reused across calls because the cached solvers are bound to them.
        self.x_out = fd.Function(self.fs)
        self.solvers = {}
        self._u = fd.Function(self.fs)

        super().__init__(self.fs)
        self._makeAttributeAndRegister('LHS_cache_size', 'apply_bcs', localVars=locals(), readOnly=True)
        self.work_counters['rhs'] = WorkCounter()
        self.work_counters['ksp'] = WorkCounter()
        self.work_counters['solver_setup'] = WorkCounter()
        self.work_counters['solver'] = WorkCounter()

    @property
    def bcs(self):
        """Boundary conditions of this field from the equation, or None when `apply_bcs` is False."""
        if not self.apply_bcs:
            return None
        else:
            return self.equation.bcs[self.equation.field_name]

    def invert_mass_matrix(self, rhs):
        """
        Solve ``M(x) - M(rhs) = 0`` for x, subject to `bcs`.

        NOTE(review): with an injective mass form this reproduces `rhs` (with boundary conditions imposed) rather
        than applying an inverse mass matrix to it -- confirm this is the intended behaviour.

        Args:
            rhs (firedrake_mesh): input data

        Returns:
            firedrake_mesh: a new mesh object built from the solution stored in ``self.x_out``
        """
        self._u.assign(rhs.functionspace)

        if 'mass_matrix' not in self.solvers:
            mass_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )
            rhs_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self._u, old_idx=self.idx),
                map_if_false=drop,
            )

            problem = fd.NonlinearVariationalProblem((mass_form - rhs_form).form, self.x_out, bcs=self.bcs)
            solver_name = self.field_name + self.__class__.__name__
            self.solvers['mass_matrix'] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            self.work_counters['solver_setup']()

        self.solvers['mass_matrix'].solve()

        return self.dtype_u(self.x_out)

    def eval_f(self, u, *args):
        """
        Evaluate the right hand side by solving ``M(x) + N(u) = 0`` for x.

        Args:
            u (firedrake_mesh): state at which to evaluate the right hand side
            *args: ignored (e.g. the time)

        Returns:
            firedrake_mesh: a new mesh object built from the solution stored in ``self.x_out``
        """
        self._u.assign(u.functionspace)

        if 'eval_rhs' not in self.solvers:
            # all non-time-derivative terms, evaluated at the known state
            residual = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_false=replace_subject(self._u, old_idx=self.idx),
                map_if_true=drop,
            )
            # mass term acting on the unknown
            mass_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )

            problem = fd.NonlinearVariationalProblem((mass_form + residual).form, self.x_out, bcs=self.bcs)
            solver_name = self.field_name + self.__class__.__name__
            self.solvers['eval_rhs'] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            self.work_counters['solver_setup']()

        self.solvers['eval_rhs'].solve()
        self.work_counters['rhs']()

        return self.dtype_f(self.x_out)

    def solve_system(self, rhs, factor, u0, *args):
        """
        Solve ``M(x) + factor * N(x) - M(rhs) = 0`` for x, i.e. ``(M - factor*f)(x) = rhs``.

        One solver is cached per distinct `factor`. At most `LHS_cache_size` of them are kept; when the limit is
        reached, the oldest cached implicit-step solver is evicted.

        Args:
            rhs (firedrake_mesh): right hand side of the system
            factor (float): factor multiplying the non-time-derivative terms
            u0 (firedrake_mesh): initial guess for the nonlinear solver
            *args: ignored

        Returns:
            firedrake_mesh: a new mesh object built from the solution stored in ``self.x_out``

        Raises:
            firedrake.exceptions.ConvergenceError: only if `stop_at_divergence` is set; otherwise the error is
                logged at debug level and the (unconverged) iterate is returned.
        """
        self.x_out.assign(u0.functionspace)  # set initial guess
        self._u.assign(rhs.functionspace)

        if factor not in self.solvers:
            # Only numeric keys are implicit-step solvers; string-keyed right hand side / mass solvers are never
            # evicted. Counting numeric keys directly keeps exactly LHS_cache_size implicit-step solvers and avoids
            # an IndexError when the cache is "full" of string-keyed solvers only.
            lhs_keys = [key for key in self.solvers if isinstance(key, (int, float, np.integer, np.floating))]
            if lhs_keys and len(lhs_keys) >= self.LHS_cache_size:
                # dicts preserve insertion order, so this evicts the oldest implicit-step solver
                self.solvers.pop(lhs_keys[0])

            # setup left hand side (M - factor*f)(u)
            # put in output variable
            residual = self.residual.label_map(all_terms, map_if_true=replace_subject(self.x_out, old_idx=self.idx))
            # multiply f by factor
            residual = residual.label_map(
                lambda t: t.has_label(time_derivative), map_if_false=lambda t: fd.Constant(factor) * t
            )

            # subtract right hand side
            mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative), map_if_false=drop)
            residual -= mass_form.label_map(all_terms, map_if_true=replace_subject(self._u, old_idx=self.idx))

            # construct solver
            problem = fd.NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)
            solver_name = f'{self.field_name}-{self.__class__.__name__}-{factor}'
            self.solvers[factor] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            self.work_counters['solver_setup']()

        try:
            self.solvers[factor].solve()
        except fd.exceptions.ConvergenceError as error:
            if self.stop_at_divergence:
                raise
            self.logger.debug(error)

        self.work_counters['ksp'].niter += self.solvers[factor].snes.getLinearSolveIterations()
        self.work_counters['solver']()
        return self.dtype_u(self.x_out)

180 

181 

class GenericGustoImex(GenericGusto):
    """
    IMEX variant of :class:`GenericGusto`.

    The right hand side is split into the residual terms labelled ``implicit`` and those labelled ``explicit``.
    Only the implicit part enters the left hand side in `solve_system`. As in the parent class, ``M(.)`` denotes
    the ``time_derivative``-labelled terms of the residual.
    """

    dtype_f = IMEX_firedrake_mesh
    # right hand side solvers cached under the `implicit` and `explicit` labels
    rhs_n_labels = 2

    def evaluate_labeled_term(self, u, label):
        """
        Solve ``M(x) + N_label(u) = 0`` for x.

        Here ``N_label`` are the non-time-derivative residual terms carrying `label`.

        Args:
            u (firedrake_mesh): state at which the labelled terms are evaluated
            label (:class:`Label`): label selecting the terms, e.g. `implicit` or `explicit`

        Returns:
            firedrake.Function: ``self.x_out`` itself, i.e. a shared buffer that the next solve overwrites. Copy it
            before calling this method again.
        """
        self._u.assign(u.functionspace)

        if label not in self.solvers:
            residual = self.residual.label_map(
                lambda t: t.has_label(label) and not t.has_label(time_derivative),
                map_if_true=replace_subject(self._u, old_idx=self.idx),
                map_if_false=drop,
            )
            mass_form = self.residual.label_map(
                lambda t: t.has_label(time_derivative),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )

            problem = fd.NonlinearVariationalProblem((mass_form + residual).form, self.x_out, bcs=self.bcs)
            solver_name = self.field_name + self.__class__.__name__
            self.solvers[label] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            # increment (do not reset) the setup counter
            self.work_counters['solver_setup']()

        self.solvers[label].solve()
        return self.x_out

    def evaluate_labeled_term_checked(self, u, label):  # pragma: no cover - thin alias kept private in spirit
        """Alias of `evaluate_labeled_term` (kept for symmetry; not used internally)."""
        return self.evaluate_labeled_term(u, label)

    def eval_f(self, u, *args):
        """
        Evaluate the implicit and explicit parts of the right hand side separately.

        Args:
            u (firedrake_mesh): state at which to evaluate the right hand side
            *args: ignored (e.g. the time)

        Returns:
            IMEX_firedrake_mesh: with ``impl`` and ``expl`` filled from the ``implicit`` and ``explicit`` solves
        """
        me = self.dtype_f(self.init)
        # assign immediately: evaluate_labeled_term returns the shared buffer self.x_out
        me.impl.assign(self.evaluate_labeled_term(u, implicit))
        me.expl.assign(self.evaluate_labeled_term(u, explicit))
        self.work_counters['rhs']()
        return me

    def solve_system(self, rhs, factor, u0, *args):
        """
        Solve ``M(x) + factor * N_I(x) - M(rhs) = 0`` for x, i.e. ``(M - factor*f_I)(x) = rhs``.

        ``N_I`` are the non-time-derivative terms labelled ``implicit``. Terms that carry neither the
        ``time_derivative`` nor the ``implicit`` label are dropped. One solver is cached per distinct `factor`; at
        most `LHS_cache_size` of them are kept, and the oldest is evicted when the limit is reached.

        Args:
            rhs (firedrake_mesh): right hand side of the system
            factor (float): factor multiplying the implicit terms
            u0 (firedrake_mesh): initial guess for the nonlinear solver
            *args: ignored

        Returns:
            firedrake_mesh: a new mesh object built from the solution stored in ``self.x_out``

        Raises:
            firedrake.exceptions.ConvergenceError: only if `stop_at_divergence` is set; otherwise the error is
                logged at debug level and the (unconverged) iterate is returned.
        """
        self.x_out.assign(u0.functionspace)  # set initial guess
        self._u.assign(rhs.functionspace)

        if factor not in self.solvers:
            # Only numeric keys are implicit-step solvers; label / string keyed solvers are never evicted. Counting
            # numeric keys directly keeps exactly LHS_cache_size implicit-step solvers and avoids an IndexError when
            # the cache holds only non-numeric keys.
            lhs_keys = [key for key in self.solvers if isinstance(key, (int, float, np.integer, np.floating))]
            if lhs_keys and len(lhs_keys) >= self.LHS_cache_size:
                # dicts preserve insertion order, so this evicts the oldest implicit-step solver
                self.solvers.pop(lhs_keys[0])

            # setup left hand side (M - factor*f_I)(u)
            # put in output variable
            residual = self.residual.label_map(
                lambda t: t.has_label(time_derivative) or t.has_label(implicit),
                map_if_true=replace_subject(self.x_out, old_idx=self.idx),
                map_if_false=drop,
            )
            # multiply f_I by factor
            residual = residual.label_map(
                lambda t: t.has_label(implicit) and not t.has_label(time_derivative),
                map_if_true=lambda t: fd.Constant(factor) * t,
            )

            # subtract right hand side
            mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative), map_if_false=drop)
            residual -= mass_form.label_map(all_terms, map_if_true=replace_subject(self._u, old_idx=self.idx))

            # construct solver
            problem = fd.NonlinearVariationalProblem(residual.form, self.x_out, bcs=self.bcs)
            solver_name = f'{self.field_name}-{self.__class__.__name__}-{factor}'
            self.solvers[factor] = fd.NonlinearVariationalSolver(
                problem, solver_parameters=self.solver_parameters, options_prefix=solver_name
            )
            # increment (do not reset) the setup counter
            self.work_counters['solver_setup']()

        # Solve exactly once, inside the try, so that `stop_at_divergence=False` actually suppresses divergence and
        # the KSP iteration count below refers to this single solve.
        try:
            self.solvers[factor].solve()
        except fd.exceptions.ConvergenceError as error:
            if self.stop_at_divergence:
                raise
            self.logger.debug(error)

        self.work_counters['ksp'].niter += self.solvers[factor].snes.getLinearSolveIterations()
        self.work_counters['solver']()
        return self.dtype_u(self.x_out)