from pySDC.core.problem import Problem, WorkCounter
from pySDC.helpers.spectral_helper import SpectralHelper
import numpy as np
from pySDC.core.errors import ParameterError
from pySDC.helpers.fieldsIO import Rectilinear


class GenericSpectralLinear(Problem):
    """
    Generic class to solve problems of the form M u_t + L u = y, with mass matrix M, linear operator L and some right
    hand side y using spectral methods.
    L may contain algebraic conditions, as long as (M + dt L) is invertible.

    Note that the `__getattr__` method is overloaded to pass requests on to the spectral helper if they are not
    attributes of this class itself. For instance, you can add a BC by calling `self.spectral.add_BC` or equivalently
    `self.add_BC`.

    You can port problems derived from this more or less seamlessly to GPU by using the numerical libraries that are
    class attributes of the spectral helper. This class will automatically switch the datatype using the `setup_GPU` method.

    Attributes:
        spectral (pySDC.helpers.spectral_helper.SpectralHelper): Spectral helper
        work_counters (dict): Dictionary for counting work
        cached_factorizations (dict): Dictionary of cached matrix factorizations for solving
        L (sparse matrix): Linear operator
        M (sparse matrix): Mass matrix
        diff_mask (list): Mask for separating differential and algebraic terms
        Pl (sparse matrix): Left preconditioner
        Pr (sparse matrix): Right preconditioner
    """

    def setup_GPU(self):
        """switch to GPU modules"""
        import cupy as cp
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh, imex_cupy_mesh
        from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh

        self.dtype_u = cupy_mesh

        GPU_versions = {
            mesh: cupy_mesh,
            imex_mesh: imex_cupy_mesh,
        }

        self.dtype_f = GPU_versions[self.dtype_f]

        if self.comm is not None:
            from pySDC.helpers.NCCL_communicator import NCCLComm

            if not isinstance(self.comm, NCCLComm):
                self.__dict__['comm'] = NCCLComm(self.comm)

    def __init__(
        self,
        bases,
        components,
        comm=None,
        Dirichlet_recombination=True,
        left_preconditioner=True,
        solver_type='cached_direct',
        solver_args=None,
        useGPU=False,
        max_cached_factorizations=12,
        spectral_space=True,
        real_spectral_coefficients=False,
        debug=False,
    ):
        """
        Base class for problems discretized with spectral methods.

        Args:
            bases (list of dictionaries): 1D bases
            components (list of strings): Components of the equations
            comm (mpi4py.Intracomm or None): MPI communicator
            Dirichlet_recombination (bool): Use Dirichlet recombination in the last axis as right preconditioner
            left_preconditioner (bool): If True, reverse the Kronecker product
            solver_type (str): Solver for linear systems
            solver_args (dict): Arguments for the linear solver
            useGPU (bool): Run on GPU or CPU
            max_cached_factorizations (int): Number of matrix decompositions to cache before starting eviction
            spectral_space (bool): If True, the solution is not transformed back after solving and evaluating the RHS, and is expected in spectral space as input to these functions
            real_spectral_coefficients (bool): If True, allow only real values in spectral space, otherwise allow complex values
            debug (bool): Perform additional tests at extra computational cost
        """
        solver_args = {} if solver_args is None else solver_args
        self._makeAttributeAndRegister(
            'max_cached_factorizations',
            'useGPU',
            'solver_type',
            'solver_args',
            'left_preconditioner',
            'Dirichlet_recombination',
            'comm',
            'spectral_space',
            'real_spectral_coefficients',
            'debug',
            localVars=locals(),
        )
        self.spectral = SpectralHelper(comm=comm, useGPU=useGPU, debug=debug)

        if useGPU:
            self.setup_GPU()
            if self.solver_args is not None:
                if 'rtol' in self.solver_args.keys():
                    self.solver_args['tol'] = self.solver_args.pop('rtol')

        for base in bases:
            self.spectral.add_axis(**base)
        self.spectral.add_component(components)

        self.spectral.setup_fft(real_spectral_coefficients)

        super().__init__(init=self.spectral.init_forward if spectral_space else self.spectral.init)

        self.work_counters[solver_type] = WorkCounter()
        self.work_counters['factorizations'] = WorkCounter()

        self.setup_preconditioner(Dirichlet_recombination, left_preconditioner)

        self.cached_factorizations = {}

    def __getattr__(self, name):
        """
        Pass requests on to the helper if they are not directly attributes of this class for convenience.

        Args:
            name (str): Name of the attribute you want

        Returns:
            request
        """
        return getattr(self.spectral, name)

    def _setup_operator(self, LHS):
        """
        Setup a sparse linear operator by adding relationships. See documentation for ``GenericSpectralLinear.setup_L`` to learn more.

        Args:
            LHS (dict): Equations to be added to the operator

        Returns:
            sparse linear operator
        """
        operator = self.spectral.get_empty_operator_matrix()
        for line, equation in LHS.items():
            self.spectral.add_equation_lhs(operator, line, equation)
        return self.spectral.convert_operator_matrix_to_operator(operator)

    def setup_L(self, LHS):
        """
        Setup the left hand side of the linear operator L and store it in ``self.L``.

        The argument is meant to be a dictionary with the line you want to write the equation in as the key and the relationship between components as another dictionary. For instance, you can add an algebraic condition capturing a first derivative relationship between u and ux as follows:

        ```
        Dx = self.get_differentiation_matrix(axes=(0,))
        I = self.get_Id()
        LHS = {'ux': {'u': Dx, 'ux': -I}}
        self.setup_L(LHS)
        ```

        If you put zero as right hand side for the solver in the line for ux, ux will contain the x-derivative of u afterwards.

        Args:
            LHS (dict): Dictionary containing the equations.
        """
        self.L = self._setup_operator(LHS)

    def setup_M(self, LHS):
        '''
        Setup mass matrix, see documentation of ``GenericSpectralLinear.setup_L``.
        '''
        diff_index = list(LHS.keys())
        self.diff_mask = [me in diff_index for me in self.components]
        self.M = self._setup_operator(LHS)

    def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner=True):
        """
        Get left and right preconditioners.

        Args:
            Dirichlet_recombination (bool): Basis conversion for right preconditioner. Useful for Chebychev and Ultraspherical methods. 10/10 would recommend.
            left_preconditioner (bool): If True, it will interleave the variables and reverse the Kronecker product
        """
        sp = self.spectral.sparse_lib
        N = np.prod(self.init[0][1:])

        Id = sp.eye(N)
        Pl_lhs = {comp: {comp: Id} for comp in self.components}
        self.Pl = self._setup_operator(Pl_lhs)

        if left_preconditioner:
            # reverse Kronecker product
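            # R permutes the unknowns from component-major ordering (all of component 0,
            # then all of component 1, ...) to point-major ordering, where the values of all
            # components at a given collocation point sit next to each other. For example,
            # with ncomponents=2 and N=3, the vector (u0, u1, u2, v0, v1, v2) is reordered
            # to (u0, v0, u1, v1, u2, v2).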
            if self.spectral.useGPU:
                R = self.Pl.get().tolil() * 0
            else:
                R = self.Pl.tolil() * 0

            for j in range(self.ncomponents):
                for i in range(N):
                    R[i * self.ncomponents + j, j * N + i] = 1.0

            self.Pl = self.spectral.sparse_lib.csc_matrix(R)

        if Dirichlet_recombination and type(self.axes[-1]).__name__ in ['ChebychevHelper', 'Ultraspherical']:
            _Pr = self.spectral.get_Dirichlet_recombination_matrix(axis=-1)
        else:
            _Pr = Id

        Pr_lhs = {comp: {comp: _Pr} for comp in self.components}
        self.Pr = self._setup_operator(Pr_lhs) @ self.Pl.T

    def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs):
        """
        Do an implicit Euler step to solve M u_t + Lu = rhs, with M the mass matrix and L the linear operator as setup by
        ``GenericSpectralLinear.setup_L`` and ``GenericSpectralLinear.setup_M``.

        The implicit Euler step is (M + dt L) u = M rhs. Note that M need not be invertible as long as (M + dt*L) is.
        This means solving with dt=0 to mimic explicit methods does not work for all problems, in particular simple DAEs.

        Note that by putting M rhs on the right hand side, this function can only solve algebraic conditions equal to
        zero. If you want something else, it should be easy to overload this function.
        """

        sp = self.spectral.sparse_lib

        if self.spectral_space:
            rhs_hat = rhs.copy()
            if u0 is not None:
                u0_hat = self.Pr.T @ u0.copy().flatten()
        else:
            rhs_hat = self.spectral.transform(rhs)
            if u0 is not None:
                u0_hat = self.Pr.T @ self.spectral.transform(u0).flatten()

        if self.useGPU:
            self.xp.cuda.Device().synchronize()
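
        # Assemble the right hand side of the linear system: apply the mass matrix,
        # insert the boundary conditions into their designated rows, and apply the
        # left preconditioner so the ordering matches the preconditioned operator below.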
        rhs_hat = (self.M @ rhs_hat.flatten()).reshape(rhs_hat.shape)
        rhs_hat = self.spectral.put_BCs_in_rhs_hat(rhs_hat)
        rhs_hat = self.Pl @ rhs_hat.flatten()

        if dt not in self.cached_factorizations.keys() or not self.solver_type.lower() == 'cached_direct':
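            # The operator for the implicit Euler step is (M + dt*L) with the boundary
            # conditions inserted into their designated rows. Left and right preconditioning
            # turns this into Pl @ A @ Pr; the solver then acts on the preconditioned
            # unknowns, and the solution is mapped back via u = Pr @ _sol_hat below.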
            A = self.M + dt * self.L
            A = self.Pl @ self.spectral.put_BCs_in_matrix(A) @ self.Pr

            # import numpy as np
            # if A.shape[0] < 200:
            #     import matplotlib.pyplot as plt
            #
            #     # M = self.spectral.put_BCs_in_matrix(self.L.copy())
            #     M = A  # self.L
            #     im = plt.imshow((M / abs(M)).real)
            #     # im = plt.imshow(np.log10(abs(A.toarray())).real)
            #     # im = plt.imshow(((A.toarray())).real)
            #     plt.colorbar(im)
            #     plt.show()

        if self.solver_type.lower() == 'cached_direct':
            if dt not in self.cached_factorizations.keys():
                if len(self.cached_factorizations) >= self.max_cached_factorizations:
                    self.cached_factorizations.pop(list(self.cached_factorizations.keys())[0])
                    self.logger.debug(f'Evicted matrix factorization for {dt=:.6f} from cache')
                self.cached_factorizations[dt] = self.spectral.linalg.factorized(A)
                self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
                self.work_counters['factorizations']()

            _sol_hat = self.cached_factorizations[dt](rhs_hat)
            self.logger.debug(f'Used cached matrix factorization for {dt=:.6f}')

        elif self.solver_type.lower() == 'direct':
            _sol_hat = sp.linalg.spsolve(A, rhs_hat)
        elif self.solver_type.lower() == 'lsqr':
            lsqr = sp.linalg.lsqr(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
            )
            _sol_hat = lsqr[0]
        elif self.solver_type.lower() == 'gmres':
            _sol_hat, _ = sp.linalg.gmres(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
                callback=self.work_counters[self.solver_type],
                callback_type='pr_norm',
            )
        elif self.solver_type.lower() == 'gmres+ilu':
            linalg = self.spectral.linalg

            if dt not in self.cached_factorizations.keys():
                if len(self.cached_factorizations) >= self.max_cached_factorizations:
                    to_evict = list(self.cached_factorizations.keys())[0]
                    self.cached_factorizations.pop(to_evict)
                    self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')
                iLU = linalg.spilu(A, drop_tol=dt * 1e-4, fill_factor=100)
                self.cached_factorizations[dt] = linalg.LinearOperator(A.shape, iLU.solve)
                self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
                self.work_counters['factorizations']()

            _sol_hat, _ = linalg.gmres(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
                callback=self.work_counters[self.solver_type],
                callback_type='pr_norm',
                M=self.cached_factorizations[dt],
            )
        elif self.solver_type.lower() == 'cg':
            _sol_hat, _ = sp.linalg.cg(
                A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
            )
        else:
            raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!')

        sol_hat = self.spectral.u_init_forward
        sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape)

        if self.useGPU:
            self.xp.cuda.Device().synchronize()

        if self.spectral_space:
            return sol_hat
        else:
            sol = self.spectral.u_init
            sol[:] = self.spectral.itransform(sol_hat).real

            if self.spectral.debug:
                self.spectral.check_BCs(sol)

            return sol

    def setUpFieldsIO(self):
        Rectilinear.setupMPI(
            comm=self.comm,
            iLoc=[me.start for me in self.local_slice],
            nLoc=[me.stop - me.start for me in self.local_slice],
        )

    def getOutputFile(self, fileName):
        self.setUpFieldsIO()

        coords = [me.get_1dgrid() for me in self.spectral.axes]
        assert np.allclose([len(me) for me in coords], self.spectral.global_shape[1:])

        fOut = Rectilinear(np.float64, fileName=fileName)
        fOut.setHeader(nVar=len(self.components), coords=coords)
        fOut.initialize()
        return fOut

    def processSolutionForOutput(self, u):
        if self.spectral_space:
            return np.array(self.itransform(u).real)
        else:
            return np.array(u.real)


def compute_residual_DAE(self, stage=''):
    """
    Computation of the residual that does not add u_0 - u_m in algebraic equations.

    Args:
        stage (str): The current stage of the step the level belongs to
    """

    # get current level and problem description
    L = self.level

    # Check if we want to skip the residual computation to gain performance
    # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
    if stage in self.params.skip_residual_computation:
        L.status.residual = 0.0 if L.status.residual is None else L.status.residual
        return None

    # check if there are new values (e.g. from a sweep)
    # assert L.status.updated

    # compute the residual for each node

    # build QF(u)
    res_norm = []
    res = self.integrate()
    mask = L.prob.diff_mask
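    # u_0 - u_m enters the residual only in the differential components (selected by the
    # problem's diff_mask); the algebraic constraints carry no time derivative, so only
    # the integral term contributes to their residual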
    for m in range(self.coll.num_nodes):
        res[m][mask] += L.u[0][mask] - L.u[m + 1][mask]
        # add tau if associated
        if L.tau[m] is not None:
            res[m] += L.tau[m]
        # use abs function from data type here
        res_norm.append(abs(res[m]))

    # find maximal residual over the nodes
    if L.params.residual_type == 'full_abs':
        L.status.residual = max(res_norm)
    elif L.params.residual_type == 'last_abs':
        L.status.residual = res_norm[-1]
    elif L.params.residual_type == 'full_rel':
        L.status.residual = max(res_norm) / abs(L.u[0])
    elif L.params.residual_type == 'last_rel':
        L.status.residual = res_norm[-1] / abs(L.u[0])
    else:
        raise ParameterError(
            f'residual_type = {L.params.residual_type} not implemented, choose '
            f'full_abs, last_abs, full_rel or last_rel instead'
        )

    # indicate that the residual has seen the new values
    L.status.updated = False

    return None


def compute_residual_DAE_MPI(self, stage=None):
    """
    Computation of the residual using the collocation matrix Q

    Args:
        stage (str): The current stage of the step the level belongs to
    """
    from mpi4py import MPI

    L = self.level

    # Check if we want to skip the residual computation to gain performance
    # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
    if stage in self.params.skip_residual_computation:
        L.status.residual = 0.0 if L.status.residual is None else L.status.residual
        return None

    # compute the residual for each node

    # build QF(u)
    res = self.integrate(last_only=L.params.residual_type[:4] == 'last')
    mask = L.prob.diff_mask
    res[mask] += L.u[0][mask] - L.u[self.rank + 1][mask]
    # add tau if associated
    if L.tau[self.rank] is not None:
        res += L.tau[self.rank]
    # use abs function from data type here
    res_norm = abs(res)

    # find maximal residual over the nodes
    if L.params.residual_type == 'full_abs':
        L.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX)
    elif L.params.residual_type == 'last_abs':
        L.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1)
    elif L.params.residual_type == 'full_rel':
        L.status.residual = self.comm.allreduce(res_norm / abs(L.u[0]), op=MPI.MAX)
    elif L.params.residual_type == 'last_rel':
        L.status.residual = self.comm.bcast(res_norm / abs(L.u[0]), root=self.comm.size - 1)
    else:
        raise NotImplementedError(f'residual type \"{L.params.residual_type}\" not implemented!')

    # indicate that the residual has seen the new values
    L.status.updated = False

    return None
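

# Note: compute_residual_DAE and compute_residual_DAE_MPI share the signature of a sweeper's
# compute_residual method and are intended to replace it when the problem contains algebraic
# constraints. A sketch of one possible way to attach them (the exact sweeper class depends
# on your setup and is an assumption here):
#
#     from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
#     generic_implicit.compute_residual = compute_residual_DAE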


def get_extrapolated_error_DAE(self, S, **kwargs):
    """
    The extrapolation estimate combines values of u and f from multiple steps to extrapolate and compare to the
    solution obtained by the time marching scheme. This function can be used in `EstimateExtrapolationError`.

    Args:
        S (pySDC.Step): The current step

    Returns:
        None
    """
    u_ex = self.get_extrapolated_solution(S)
    diff_mask = S.levels[0].prob.diff_mask
    if u_ex is not None:
        S.levels[0].status.error_extrapolation_estimate = (
            abs((u_ex - S.levels[0].u[-1])[diff_mask]) * self.coeff.prefactor
        )
        # print([abs(me) for me in (u_ex - S.levels[0].u[-1]) * self.coeff.prefactor])
    else:
        S.levels[0].status.error_extrapolation_estimate = None