Coverage for pySDC/implementations/problem_classes/generic_spectral.py: 48%
176 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-04 07:15 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-04 07:15 +0000
1from pySDC.core.problem import Problem, WorkCounter
2from pySDC.helpers.spectral_helper import SpectralHelper
3import numpy as np
4from pySDC.core.errors import ParameterError
class GenericSpectralLinear(Problem):
    """
    Generic class to solve problems of the form M u_t + L u = y, with mass matrix M, linear operator L and some right
    hand side y using spectral methods.
    L may contain algebraic conditions, as long as (M + dt L) is invertible.

    Note that the `__getattr__` method is overloaded to pass requests on to the spectral helper if they are not
    attributes of this class itself. For instance, you can add a BC by calling `self.spectral.add_BC` or equivalently
    `self.add_BC`.

    You can port problems derived from this more or less seamlessly to GPU by using the numerical libraries that are
    class attributes of the spectral helper. This class will automatically switch the datatype using the `setup_GPU`
    method.

    Attributes:
        spectral (pySDC.helpers.spectral_helper.SpectralHelper): Spectral helper
        work_counters (dict): Dictionary for counting work
        cached_factorizations (dict): Matrix factorizations (or preconditioners) for solving, keyed by step size
        L (sparse matrix): Linear operator
        M (sparse matrix): Mass matrix
        diff_mask (list): Mask for separating differential and algebraic terms
        Pl (sparse matrix): Left preconditioner
        Pr (sparse matrix): Right preconditioner
    """

    def setup_GPU(self):
        """
        Switch the datatypes to their CuPy counterparts and wrap the communicator for NCCL if there is one.
        """
        import cupy as cp  # noqa: F401 -- fail early if CuPy is unavailable
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh, imex_cupy_mesh
        from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh

        self.dtype_u = cupy_mesh

        GPU_versions = {
            mesh: cupy_mesh,
            imex_mesh: imex_cupy_mesh,
        }

        self.dtype_f = GPU_versions[self.dtype_f]

        if self.comm is not None:
            from pySDC.helpers.NCCL_communicator import NCCLComm

            if not isinstance(self.comm, NCCLComm):
                self.__dict__['comm'] = NCCLComm(self.comm)

    def __init__(
        self,
        bases,
        components,
        comm=None,
        Dirichlet_recombination=True,
        left_preconditioner=True,
        solver_type='cached_direct',
        solver_args=None,
        useGPU=False,
        max_cached_factorizations=12,
        spectral_space=True,
        real_spectral_coefficients=False,
        debug=False,
    ):
        """
        Base class for problems discretized with spectral methods.

        Args:
            bases (list of dictionaries): 1D Bases
            components (list of strings): Components of the equations
            comm (mpi4py.Intracomm or None): MPI communicator
            Dirichlet_recombination (bool): Use Dirichlet recombination in the last axis as right preconditioner
            left_preconditioner (bool): Reverse the Kronecker product if yes
            solver_type (str): Solver for linear systems
            solver_args (dict): Arguments for linear solver
            useGPU (bool): Run on GPU or CPU
            max_cached_factorizations (int): Number of matrix decompositions to cache before starting eviction
            spectral_space (bool): If yes, the solution will not be transformed back after solving and evaluating the RHS, and is expected as input in spectral space to these functions
            real_spectral_coefficients (bool): If yes, allow only real values in spectral space, otherwise, allow complex.
            debug (bool): Make additional tests at extra computational cost
        """
        # copy so that renaming keys below does not modify the dictionary owned by the caller
        solver_args = {} if solver_args is None else dict(solver_args)
        self._makeAttributeAndRegister(
            'max_cached_factorizations',
            'useGPU',
            'solver_type',
            'solver_args',
            'left_preconditioner',
            'Dirichlet_recombination',
            'comm',
            'spectral_space',
            'real_spectral_coefficients',
            'debug',
            localVars=locals(),
        )
        self.spectral = SpectralHelper(comm=comm, useGPU=useGPU, debug=debug)

        if useGPU:
            self.setup_GPU()
            # the GPU solvers take the relative tolerance as `tol` rather than `rtol`
            if 'rtol' in self.solver_args:
                self.solver_args['tol'] = self.solver_args.pop('rtol')

        for base in bases:
            self.spectral.add_axis(**base)
        self.spectral.add_component(components)

        self.spectral.setup_fft(real_spectral_coefficients)

        super().__init__(init=self.spectral.init_forward if spectral_space else self.spectral.init)

        self.work_counters[solver_type] = WorkCounter()
        self.work_counters['factorizations'] = WorkCounter()

        self.setup_preconditioner(Dirichlet_recombination, left_preconditioner)

        self.cached_factorizations = {}

    def __getattr__(self, name):
        """
        Pass requests on to the helper if they are not directly attributes of this class for convenience.

        Args:
            name (str): Name of the attribute you want

        Returns:
            request

        Raises:
            AttributeError: If neither this object nor the spectral helper has the attribute
        """
        # `__getattr__` is only invoked for missing attributes. If `spectral` itself is missing (e.g. on a bare
        # instance created by `copy`/`pickle`, or before `__init__` assigned it), looking it up via `self.spectral`
        # would re-enter this method forever. Fail like a normal missing attribute instead.
        if name == 'spectral':
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        return getattr(self.spectral, name)

    def _setup_operator(self, LHS):
        """
        Setup a sparse linear operator by adding relationships. See documentation for ``GenericSpectralLinear.setup_L`` to learn more.

        Args:
            LHS (dict): Equations to be added to the operator

        Returns:
            sparse linear operator
        """
        operator = self.spectral.get_empty_operator_matrix()
        for line, equation in LHS.items():
            self.spectral.add_equation_lhs(operator, line, equation)
        return self.spectral.convert_operator_matrix_to_operator(operator)

    def setup_L(self, LHS):
        """
        Setup the left hand side of the linear operator L and store it in ``self.L``.

        The argument is meant to be a dictionary with the line you want to write the equation in as the key and the relationship between components as another dictionary. For instance, you can add an algebraic condition capturing a first derivative relationship between u and ux as follows:

        ```
        Dx = self.get_differentiation_matrix(axes=(0,))
        I = self.get_Id()
        LHS = {'ux': {'u': Dx, 'ux': -I}}
        self.setup_L(LHS)
        ```

        If you put zero as right hand side for the solver in the line for ux, ux will contain the x-derivative of u afterwards.

        Args:
            LHS (dict): Dictionary containing the equations.
        """
        self.L = self._setup_operator(LHS)

    def setup_M(self, LHS):
        '''
        Setup mass matrix, see documentation of ``GenericSpectralLinear.setup_L``.

        Components that appear as keys in ``LHS`` are flagged as differential in ``self.diff_mask``; all others are
        treated as algebraic.
        '''
        diff_index = list(LHS.keys())
        self.diff_mask = [me in diff_index for me in self.components]
        self.M = self._setup_operator(LHS)

    def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner=True):
        """
        Get left and right preconditioners and store them in ``self.Pl`` and ``self.Pr``.

        Args:
            Dirichlet_recombination (bool): Basis conversion for right preconditioner. Only applied when the last axis
                uses a Chebychev or ultraspherical basis.
            left_preconditioner (bool): If True, it will interleave the variables and reverse the Kronecker product
        """
        sp = self.spectral.sparse_lib
        N = np.prod(self.init[0][1:])

        Id = sp.eye(N)
        Pl_lhs = {comp: {comp: Id} for comp in self.components}
        self.Pl = self._setup_operator(Pl_lhs)

        if left_preconditioner:
            # Reverse the Kronecker product: entry j * N + i (component j, position i) is moved to row
            # i * ncomponents + j, which interleaves the components. The matrix is assembled on the host.
            R = (self.Pl.get() if self.spectral.useGPU else self.Pl).tolil() * 0

            ncomponents = self.ncomponents
            cols = np.arange(ncomponents * N)
            comp_idx, pos_idx = np.divmod(cols, N)
            R[pos_idx * ncomponents + comp_idx, cols] = 1.0

            self.Pl = sp.csc_matrix(R)

        # Both names are separate list entries; a single comma-joined string would never match any class name.
        if Dirichlet_recombination and type(self.axes[-1]).__name__ in ['ChebychevHelper', 'UltrasphericalHelper']:
            _Pr = self.spectral.get_Dirichlet_recombination_matrix(axis=-1)
        else:
            _Pr = Id

        Pr_lhs = {comp: {comp: _Pr} for comp in self.components}
        self.Pr = self._setup_operator(Pr_lhs) @ self.Pl.T

    def _evict_oldest_factorization(self):
        """
        Remove the oldest entry from ``self.cached_factorizations`` if the cache has reached its size limit.
        """
        if self.cached_factorizations and len(self.cached_factorizations) >= self.max_cached_factorizations:
            # dicts preserve insertion order, so the first key is the oldest cached step size
            to_evict = next(iter(self.cached_factorizations))
            self.cached_factorizations.pop(to_evict)
            self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache')

    def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs):
        """
        Do an implicit Euler step to solve M u_t + Lu = rhs, with M the mass matrix and L the linear operator as setup by
        ``GenericSpectralLinear.setup_L`` and ``GenericSpectralLinear.setup_M``.

        The implicit Euler step is (M + dt L) u = M rhs. Note that M need not be invertible as long as (M + dt L) is.
        This means solving with dt=0 to mimic explicit methods does not work for all problems, in particular simple DAEs.

        Note that by putting M rhs on the right hand side, this function can only solve algebraic conditions equal to
        zero. If you want something else, it should be easy to overload this function.

        Args:
            rhs: Right hand side; in spectral space if ``self.spectral_space``, otherwise in real space
            dt (float): Step size
            u0: Optional initial guess for iterative solvers, in the same space as ``rhs``
            skip_itransform (bool): Not used by this implementation

        Returns:
            The solution; in spectral space if ``self.spectral_space``, otherwise in real space

        Raises:
            NotImplementedError: If ``self.solver_type`` is not one of the supported solvers
        """
        sp = self.spectral.sparse_lib
        solver = self.solver_type.lower()

        # Iterative solvers accept `x0=None` and then start from zero, so a missing initial guess is not an error.
        u0_hat = None
        if self.spectral_space:
            rhs_hat = rhs
            if u0 is not None:
                u0_hat = self.Pr.T @ u0.flatten()
        else:
            rhs_hat = self.spectral.transform(rhs)
            if u0 is not None:
                u0_hat = self.Pr.T @ self.spectral.transform(u0).flatten()

        if self.useGPU:
            self.xp.cuda.Device().synchronize()

        # the product with M yields a fresh array, so the BCs below never touch the caller's `rhs`
        rhs_hat = (self.M @ rhs_hat.flatten()).reshape(rhs_hat.shape)
        rhs_hat = self.spectral.put_BCs_in_rhs_hat(rhs_hat)
        rhs_hat = self.Pl @ rhs_hat.flatten()

        # The system matrix is only needed if no cached direct factorization can be reused.
        if solver != 'cached_direct' or dt not in self.cached_factorizations:
            A = self.M + dt * self.L
            A = self.Pl @ self.spectral.put_BCs_in_matrix(A) @ self.Pr

        if solver == 'cached_direct':
            if dt not in self.cached_factorizations:
                self._evict_oldest_factorization()
                self.cached_factorizations[dt] = self.spectral.linalg.factorized(A)
                self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
                self.work_counters['factorizations']()

            _sol_hat = self.cached_factorizations[dt](rhs_hat)
            self.logger.debug(f'Used cached matrix factorization for {dt=:.6f}')

        elif solver == 'direct':
            _sol_hat = sp.linalg.spsolve(A, rhs_hat)
        elif solver == 'lsqr':
            lsqr = sp.linalg.lsqr(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
            )
            _sol_hat = lsqr[0]
        elif solver == 'gmres':
            _sol_hat, _ = sp.linalg.gmres(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
                callback=self.work_counters[self.solver_type],
                callback_type='pr_norm',
            )
        elif solver == 'gmres+ilu':
            linalg = self.spectral.linalg

            # the incomplete LU decomposition is cached per step size and used as preconditioner for GMRES
            if dt not in self.cached_factorizations:
                self._evict_oldest_factorization()
                iLU = linalg.spilu(A, drop_tol=dt * 1e-4, fill_factor=100)
                self.cached_factorizations[dt] = linalg.LinearOperator(A.shape, iLU.solve)
                self.logger.debug(f'Cached matrix factorization for {dt=:.6f}')
                self.work_counters['factorizations']()

            _sol_hat, _ = linalg.gmres(
                A,
                rhs_hat,
                x0=u0_hat,
                **self.solver_args,
                callback=self.work_counters[self.solver_type],
                callback_type='pr_norm',
                M=self.cached_factorizations[dt],
            )
        elif solver == 'cg':
            _sol_hat, _ = sp.linalg.cg(
                A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type]
            )
        else:
            raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!')

        # undo the right preconditioner to recover the solution coefficients
        sol_hat = self.spectral.u_init_forward
        sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape)

        if self.useGPU:
            self.xp.cuda.Device().synchronize()

        if self.spectral_space:
            return sol_hat
        else:
            sol = self.spectral.u_init
            sol[:] = self.spectral.itransform(sol_hat).real

            if self.spectral.debug:
                self.spectral.check_BCs(sol)

            return sol
def compute_residual_DAE(self, stage=''):
    """
    Compute the collocation residual without adding u_0 - u_m in the algebraic equations.

    Only components flagged as differential in ``diff_mask`` of the problem get the u_0 - u_m contribution; algebraic
    components keep the integral term (plus tau, if present).

    Args:
        stage (str): The current stage of the step the level belongs to

    Returns:
        None

    Raises:
        ParameterError: If the residual type of the level is not recognized
    """
    lvl = self.level

    # Skipping the residual computation gains performance, but the stored residual is then likely to be incorrect!
    if stage in self.params.skip_residual_computation:
        lvl.status.residual = lvl.status.residual if lvl.status.residual is not None else 0.0
        return None

    # start from QF(u) and add the remaining terms node by node
    residuals = self.integrate()
    diff_mask = lvl.prob.diff_mask
    norms = []
    for node in range(self.coll.num_nodes):
        residuals[node][diff_mask] += lvl.u[0][diff_mask] - lvl.u[node + 1][diff_mask]
        tau = lvl.tau[node]
        if tau is not None:
            residuals[node] += tau
        # the datatype's own abs is used as norm
        norms.append(abs(residuals[node]))

    reducers = {
        'full_abs': lambda: max(norms),
        'last_abs': lambda: norms[-1],
        'full_rel': lambda: max(norms) / abs(lvl.u[0]),
        'last_rel': lambda: norms[-1] / abs(lvl.u[0]),
    }
    residual_type = lvl.params.residual_type
    if residual_type not in reducers:
        raise ParameterError(
            f'residual_type = {lvl.params.residual_type} not implemented, choose '
            f'full_abs, last_abs, full_rel or last_rel instead'
        )
    lvl.status.residual = reducers[residual_type]()

    # mark that the residual reflects the current values
    lvl.status.updated = False

    return None
def compute_residual_DAE_MPI(self, stage=None):
    """
    Compute the residual of the node owned by this rank, adding u_0 - u_m only in the differential equations, and
    reduce it across the communicator.

    Args:
        stage (str): The current stage of the step the level belongs to

    Returns:
        None

    Raises:
        NotImplementedError: If the residual type of the level is not recognized
    """
    from mpi4py import MPI

    lvl = self.level

    # Skipping the residual computation gains performance, but the stored residual is then likely to be incorrect!
    if stage in self.params.skip_residual_computation:
        lvl.status.residual = lvl.status.residual if lvl.status.residual is not None else 0.0
        return None

    # QF(u) at this rank's node, or only at the last node if just that one is required
    residual = self.integrate(last_only=lvl.params.residual_type.startswith('last'))
    diff_mask = lvl.prob.diff_mask
    node = self.rank + 1
    residual[diff_mask] += lvl.u[0][diff_mask] - lvl.u[node][diff_mask]
    tau = lvl.tau[self.rank]
    if tau is not None:
        residual += tau
    # the datatype's own abs is used as norm
    local_norm = abs(residual)

    last_rank = self.comm.size - 1
    kind = lvl.params.residual_type
    if kind == 'full_abs':
        lvl.status.residual = self.comm.allreduce(local_norm, op=MPI.MAX)
    elif kind == 'last_abs':
        lvl.status.residual = self.comm.bcast(local_norm, root=last_rank)
    elif kind == 'full_rel':
        lvl.status.residual = self.comm.allreduce(local_norm / abs(lvl.u[0]), op=MPI.MAX)
    elif kind == 'last_rel':
        lvl.status.residual = self.comm.bcast(local_norm / abs(lvl.u[0]), root=last_rank)
    else:
        raise NotImplementedError(f'residual type \"{lvl.params.residual_type}\" not implemented!')

    # mark that the residual reflects the current values
    lvl.status.updated = False

    return None
def get_extrapolated_error_DAE(self, S, **kwargs):
    """
    Estimate the error by comparing an extrapolated solution to the one from the time marching scheme, using only the
    differential components. This function can be used in `EstimateExtrapolationError`.

    The estimate is stored in ``S.levels[0].status.error_extrapolation_estimate``; it is set to None when no
    extrapolated solution is available.

    Args:
        S (pySDC.Step): The current step

    Returns:
        None
    """
    u_ex = self.get_extrapolated_solution(S)
    level = S.levels[0]
    diff_mask = level.prob.diff_mask

    if u_ex is None:
        level.status.error_extrapolation_estimate = None
        return None

    difference = (u_ex - level.u[-1])[diff_mask]
    level.status.error_extrapolation_estimate = abs(difference) * self.coeff.prefactor
    return None