Coverage for pySDC/implementations/problem_classes/generic_ND_FD.py: 94%

71 statements

Report created at 2024-09-09 14:59 +0000

1#!/usr/bin/env python3

2# -*- coding: utf-8 -*-

3"""

4Created on Sat Feb 11 22:39:30 2023

5"""

6import numpy as np

7import scipy.sparse as sp

8from scipy.sparse.linalg import gmres, spsolve, cg

10from pySDC.core.errors import ProblemError

11from pySDC.core.problem import Problem, WorkCounter

12from pySDC.helpers import problem_helper

13from pySDC.implementations.datatype_classes.mesh import mesh

class GenericNDimFinDiff(Problem):
    r"""
    Base class for finite difference spatial discretisation in :math:`N` dimensions

    .. math::
        \frac{d u}{dt} = A u,

    where :math:`A \in \mathbb{R}^{n^N \times n^N}` is a matrix arising from finite difference discretisation of
    spatial derivatives with :math:`n` degrees of freedom per dimension and :math:`N` dimensions. This generic class
    follows the MOL (method-of-lines) approach and can be used to discretize partial differential equations such as
    the advection equation and the heat equation.

    Parameters
    ----------
    nvars : int or tuple of int, optional
        Spatial resolution for the ND problem. For :math:`N = 2`, set ``nvars=(16, 16)``.
        At most three dimensions are supported, and all entries must be equal (square domain).
    coeff : float, optional
        Factor for finite difference matrix :math:`A`.
    derivative : int, optional
        Order of the spatial derivative.
    freq : int or tuple of int, optional
        Spatial frequency, can be a tuple with one entry per dimension. An int is used for every dimension.
    stencil_type : str, optional
        Stencil type for finite differences.
    order : int, optional
        Order of accuracy of the finite difference discretization.
    lintol : float, optional
        Tolerance for spatial solver.
    liniter : int, optional
        Maximum number of iterations for linear solver.
    solver_type : str, optional
        Type of solver. Can be 'direct', 'GMRES' or 'CG'.
    bc : str or tuple of 2 strings, optional
        Type of boundary conditions. Default is 'periodic'.
        To define two different types of boundary condition for each side,
        you can use a tuple, for instance ``bc=("dirichlet", "neumann")``
        uses Dirichlet BC on the left side, and Neumann BC on the right side.
    bcParams : dict, optional
        Parameters for boundary conditions, which can contain these keys:

        - **val** : value for the boundary value (Dirichlet) or derivative
          (Neumann), defaults to 0
        - **reduce** : if true, reduce the order of the A matrix close to the
          boundary. If false (default), use shifted stencils close to the
          boundary.
        - **neumann_bc_order** : finite difference order that should be used
          for the neumann BC derivative. If None (default), uses the same
          order as the discretization for A.

        Default is None, which takes the default values for each parameter.
        You can also define a tuple to set different parameters for each
        side.

        NOTE(review): this base class accepts ``bcParams`` but does not currently
        forward it to the finite difference matrix assembly.

    Attributes
    ----------
    A : sparse matrix (CSC)
        FD discretization matrix of the ND operator.
    Id : sparse matrix (CSC)
        Identity matrix of the same dimension as A.
    xvalues : np.1darray
        Values of spatial grid.
    """

    dtype_u = mesh
    dtype_f = mesh

    def __init__(
        self,
        nvars=512,
        coeff=1.0,
        derivative=1,
        freq=2,
        stencil_type='center',
        order=2,
        lintol=1e-12,
        liniter=10000,
        solver_type='direct',
        bc='periodic',
        bcParams=None,
    ):
        # make sure parameters have the correct types
        # (exact type comparison on purpose: subclasses such as bool are rejected)
        if type(nvars) not in [int, tuple]:
            raise ProblemError('nvars should be either tuple or int')
        if type(freq) not in [int, tuple]:
            raise ProblemError('freq should be either tuple or int')

        # transforms nvars into a tuple
        if type(nvars) is int:
            nvars = (nvars,)

        # automatically determine ndim from nvars
        ndim = len(nvars)
        if ndim == 0:
            # fail early with a clear message instead of an IndexError on nvars[0] further down
            raise ProblemError('nvars must contain at least one dimension, got an empty tuple')
        if ndim > 3:
            raise ProblemError(f'can work with up to three dimensions, got {ndim}')

        # eventually extend freq to other dimension
        if type(freq) is int:
            freq = (freq,) * ndim
        if len(freq) != ndim:
            raise ProblemError(f'len(freq)={len(freq)}, different to ndim={ndim}')

        # check values for freq and nvars
        for f in freq:
            if ndim == 1 and f == -1:
                # use Gaussian initial solution in 1D; this forces periodic BCs,
                # overriding whatever bc the caller passed
                bc = 'periodic'
                break
            if f % 2 != 0 and bc == 'periodic':
                raise ProblemError('need even number of frequencies due to periodic BCs')
        for nvar in nvars:
            if nvar % 2 != 0 and bc == 'periodic':
                raise ProblemError('the setup requires nvars = 2^p per dimension')
            if (nvar + 1) % 2 != 0 and bc == 'dirichlet-zero':
                raise ProblemError('setup requires nvars = 2^p - 1')
        # all entries equal <=> every element matches its successor
        if ndim > 1 and nvars[1:] != nvars[:-1]:
            # f-string: `'%s' % nvars` would treat the tuple as the argument list and raise TypeError
            raise ProblemError(f'need a square domain, got {nvars}')

        # invoke super init, passing number of dofs
        super().__init__(init=(nvars[0] if ndim == 1 else nvars, None, np.dtype('float64')))

        # 1D grid on [0, 1]; the domain is square, so the same grid is used along every axis
        dx, xvalues = problem_helper.get_1d_grid(size=nvars[0], bc=bc, left_boundary=0.0, right_boundary=1.0)

        # assemble the ND operator (the helper's second return value is not used here)
        self.A, _ = problem_helper.get_finite_difference_matrix(
            derivative=derivative,
            order=order,
            stencil_type=stencil_type,
            dx=dx,
            size=nvars[0],
            dim=ndim,
            bc=bc,
        )
        self.A *= coeff

        self.xvalues = xvalues
        # identity on the flattened state vector of size prod(nvars)
        self.Id = sp.eye(np.prod(nvars), format='csc')

        # store attribute and register them as parameters (values are read from the local variables above)
        self._makeAttributeAndRegister('nvars', 'stencil_type', 'order', 'bc', localVars=locals(), readOnly=True)
        self._makeAttributeAndRegister('freq', 'lintol', 'liniter', 'solver_type', localVars=locals())

        # iterative solvers report their iterations through a WorkCounter callback
        if self.solver_type != 'direct':
            self.work_counters[self.solver_type] = WorkCounter()

    @property
    def ndim(self):
        """Number of dimensions of the spatial problem"""
        return len(self.nvars)

    @property
    def dx(self):
        """Grid spacing (identical in all dimensions), taken from the first two grid points"""
        return self.xvalues[1] - self.xvalues[0]

    @property
    def grids(self):
        """
        ND grids associated to the problem.

        Returns the 1D grid for ``ndim == 1``, otherwise a tuple of broadcastable views of
        ``xvalues`` (one per dimension, each with a single non-singleton axis).
        """
        x = self.xvalues
        if self.ndim == 1:
            return x
        if self.ndim == 2:
            return x[None, :], x[:, None]
        if self.ndim == 3:
            return x[None, :, None], x[:, None, None], x[None, None, :]

    @classmethod
    def get_default_sweeper_class(cls):
        from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit

        return generic_implicit

    def eval_f(self, u, t):
        """
        Routine to evaluate the right-hand side of the problem.

        Parameters
        ----------
        u : dtype_u
            Current values.
        t : float
            Current time.

        Returns
        -------
        f : dtype_f
            Values of the right-hand side of the problem.
        """
        f = self.f_init
        # A acts on the flattened state; reshape back to the ND layout
        f[:] = self.A.dot(u.flatten()).reshape(self.nvars)
        return f

    def solve_system(self, rhs, factor, u0, t):
        r"""
        Simple linear solver for :math:`(I-factor\cdot A)\vec{u}=\vec{rhs}`.

        Parameters
        ----------
        rhs : dtype_f
            Right-hand side for the linear system.
        factor : float
            Abbrev. for the local stepsize (or any other factor required).
        u0 : dtype_u
            Initial guess for the iterative solver.
        t : float
            Current time (e.g. for time-dependent BCs).

        Returns
        -------
        sol : dtype_u
            The solution of the linear solver.

        Raises
        ------
        ValueError
            If ``solver_type`` is not one of 'direct', 'GMRES' or 'CG'.
        """
        solver_type, nvars, lintol, liniter = self.solver_type, self.nvars, self.lintol, self.liniter
        sol = self.u_init

        # system matrix shared by all solver branches, built once
        system = self.Id - factor * self.A

        if solver_type == 'direct':
            sol[:] = spsolve(system, rhs.flatten()).reshape(nvars)
        elif solver_type == 'GMRES':
            sol[:] = gmres(
                system,
                rhs.flatten(),
                x0=u0.flatten(),
                rtol=lintol,
                maxiter=liniter,
                atol=0,
                callback=self.work_counters[solver_type],
                callback_type='legacy',
            )[0].reshape(nvars)
        elif solver_type == 'CG':
            sol[:] = cg(
                system,
                rhs.flatten(),
                x0=u0.flatten(),
                rtol=lintol,
                maxiter=liniter,
                atol=0,
                callback=self.work_counters[solver_type],
            )[0].reshape(nvars)
        else:
            raise ValueError(f'solver type "{solver_type}" not known in generic advection-diffusion implementation!')

        return sol