Coverage for pySDC/helpers/spectral_helper.py: 89%
771 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-18 13:09 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-18 13:09 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
5from functools import partial, wraps
6import logging
def cache(func):
    """
    Decorator for caching return values of methods.
    This is very similar to `functools.cache`, but the cache is stored on the instance itself instead of in a global
    dictionary that keeps a reference to `self`, which avoids the memory leaks of caching instance methods (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/).

    The cache is a dictionary stored in the instance attribute ``_<name of func>_cache``. It is keyed by the positional
    arguments and the (order-independent) keyword arguments, so all arguments must be hashable.

    Example:

    .. code-block:: python

        class Counter:
            def __init__(self):
                self.num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        counter = Counter()
        counter.increment(0)  # returns 1, num_calls = 1
        counter.increment(1)  # returns 2, num_calls = 2
        counter.increment(0)  # returns 1, num_calls = 2

    Args:
        func (function): The method you want to cache the return value of. Its first argument must be `self`.

    Returns:
        function: Wrapped method that returns the stored value for arguments it has seen before
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # keep the cache on the instance so that it is garbage collected together with the instance
        store = getattr(self, attr_cache, None)
        if store is None:
            store = {}
            setattr(self, attr_cache, store)

        key = (args, frozenset(kwargs.items()))
        if key not in store:
            store[key] = func(self, *args, **kwargs)
        return store[key]

    return wrapper
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
        plans (dict): Storage for FFT plans
        logger (logging.Logger): Logger named after the class
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both `useGPU` and `useFFTW` are requested
        """
        # Validate first: switching modules is a class-level side effect (and imports cupy on GPU), which should not
        # happen for an invalid configuration.
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib
        # mirror `setup_CPU`, which also sets the backends and the datatype
        cls.fft_backend = 'cupyx-scipy'
        cls.fft_comm_backend = 'NCCL'
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """switch to CPU modules"""

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
            individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Modes with absolute wavenumber `k` are kept if `kmin <= k < kmax` and zeroed otherwise.

        Args:
            kmin (int): Lower limit of the bandpass filter (inclusive)
            kmax (int): Upper limit of the bandpass filter (exclusive). Defaults to the largest absolute wavenumber.

        Returns:
            sparse matrix in csc format
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        if self.useGPU:
            # The filter is assembled with scipy on the host, so the mask has to live on the host as well. Device arrays
            # cannot be used to index host matrices.
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        F = F.tocsc()

        # hand the result back on the device the rest of the discretization lives on
        return self.sparse_lib.csc_matrix(F) if self.useGPU else F

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    The physical interval [x0, x1] is mapped affinely onto the reference interval [-1, 1] of the Chebychev polynomials.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary, which is mapped to -1 in the reference interval
            x1 (float): Coordinate of right boundary, which is mapped to +1 in the reference interval
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like.

        Args:
         name (str): Conversion code, e.g. 'T2U'
         N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = N if N else self.N
        sp = self.sparse_lib

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.eye(N, k=2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.eye(N, k=2)
            elif name[0] == name[-1]:
                # conversion within the same basis: identity of the requested size N (not necessarily self.N)
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # only the forward conversions are implemented, so try to invert the reverse conversion
            try:
                fwd = get_forward_conv(name[::-1])
                import scipy.sparse as scipy_sparse

                if self.sparse_lib == scipy_sparse:
                    mat = self.sparse_lib.linalg.inv(fwd.tocsc())
                else:
                    # invert on the host and move the result back to the device
                    mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))
            except NotImplementedError:
                raise NotImplementedError from E

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
           Sparse integration matrix

        Raises:
            NotImplementedError: For any `lbnd` other than 0
        """
        S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = self.xp.arange(self.N)
        if lbnd == 0:
            S = S.tocsc()
            # use self.xp throughout so that no host and device arrays are mixed on GPUs
            denominator = self.xp.concatenate(
                (self.xp.ones(1, dtype=int), self.xp.arange(self.N // 2 - 1) + 1)
            )
            S[0, 1::2] = (
                (n / (2 * (self.xp.arange(self.N) + 1)))[1::2]
                * (-1) ** (self.xp.arange(self.N // 2))
                / denominator
            ) * self.lin_trf_fac
        else:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Get the differentiation matrix in Chebychev-T base.
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            sparse matrix (csc format) of the p-th derivative
        '''
        # Vectorised assembly of D[k, j] = 2 j ((j - k) mod 2) for k < j, with the first row halved
        j = self.xp.arange(self.N)
        k = j[:, None]
        D = self.xp.where(k < j, 2 * j * ((j - k) % 2), 0).astype(float)

        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple, optional): Passed on to the DCT as `s`

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple, optional): Passed on to the inverse DCT as `s`

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # Copy once and keep working on the copy, such that the input is not modified and the normalisations of all
        # axes accumulate (re-copying `u` per axis would discard the normalisation of all but the last axis).
        _u = u.copy()
        for axis in axes:

            if self.N != _u.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                N = self.N
                src = _u
                _u = self.xp.zeros_like(src)

                # copy first half
                su = [slice(None)] * src.ndim
                su[axis] = slice(0, N // 2 + 1)
                _u[tuple(su)] = src[tuple(su)]

                # copy second half
                su = [slice(None)] * src.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * src.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                _u[tuple(s_u)] = src[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * src.ndim
                    su[axis] = N // 2
                    _u[tuple(su)] *= 2

            # generate norm
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N

            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the reference interval [-1, 1].

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        # NOTE(review): no scaling by `lin_trf_fac` is applied, i.e. the integrals are in the reference coordinate.
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        # integral of T_k over [-1, 1] is ((-1)^k + 1) / (1 - k^2) for k >= 2, 0 for k = 1 and 2 for k = 0
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x, where x is given in the reference coordinate of the polynomials.

        Args:
            x (float): Position of the boundary condition, only -1, 0 and 1 are supported

        Returns:
            self.xp.ndarray: Row to put into a matrix

        Raises:
            NotImplementedError: For any other `x`
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_k(0) = cos(k pi / 2), i.e. 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib

        return sp.eye(N) - sp.eye(N, k=2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        if p_in == p_out:
            # nothing to convert; avoid inverting an identity matrix
            return self.sparse_lib.eye(self.N, format='csc')

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.
        The constant is sum_{k>=1} u_hat_k (-1)^(k-1) along `axis`, i.e. the value that makes the integral vanish at -1.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        n = u_hat.shape[axis]

        # select modes 1, ..., n-1 along `axis` and keep all other axes untouched
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # alternating signs, shaped to broadcast along `axis` only
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = -1
        signs = ((-1) ** self.xp.arange(n - 1)).reshape(sign_shape)

        return self.xp.sum(u_hat[tuple(slices)] * signs, axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic domains. Transforms are done with FFTs and the differentiation matrix is diagonal.
    This basis is marked as distributable.
    """

    distributable = True

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary and excluding the right one, which, due to
        periodicity, is identical to the left boundary.

        Returns:
            self.xp.ndarray: Grid
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Get the wavenumbers in the (unintuitive) FFT ordering: zero, positive, then negative frequencies.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix

        Raises:
            ValueError: If `p` is negative
        """
        if p < 0:
            raise ValueError(f'Order of the derivative must be non-negative, but got {p=}')
        k = self.get_wavenumbers()

        # The matrix is diagonal, so its p-th power is the diagonal raised to the p-th power. This needs no sparse
        # matrix power (not available in CuPy) and gives the identity for p=0 on CPU and GPU alike.
        return self.sparse_lib.diags((1j * k) ** p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix

        Raises:
            ValueError: If `p` is negative
        """
        if p < 0:
            raise ValueError(f'Order of the integral must be non-negative, but got {p=}')
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L
        # diagonal matrix: raise the diagonal to the p-th power instead of computing a sparse matrix power
        return self.sparse_lib.diags((1 / (1j * k)) ** p)

    def get_plan(self, u, forward, *args, **kwargs):
        """
        Get a callable performing the (inverse) FFT. FFTW plans are generated once per configuration and stored in
        `self.plans`; for other libraries, the FFT function with the desired normalization is returned.

        Args:
            u: Data the plan is meant for
            forward (bool): Whether to get the forward or backward transform

        Returns:
            callable transform
        """
        if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
            if 'axes' in kwargs.keys():
                kwargs['axes'] = tuple(kwargs['axes'])
            key = (forward, u.shape, args, *(me for me in kwargs.values()))
            if key in self.plans.keys():
                return self.plans[key]
            else:
                self.logger.debug(f'Generating FFT plan for {key=}')
                transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs)
                self.plans[key] = transform

            return self.plans[key]
        else:
            if forward:
                return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward'))
            else:
                return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward'))

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        FFT along axes. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform over

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # pass `forward` positionally so that additional positional `args` cannot collide with it
        plan = self.get_plan(u, True, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs)

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse FFT.

        Args:
            u: Data you want to transform
            axes (tuple): Axes over which to transform

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # pass `forward` positionally so that additional positional `args` cannot collide with it
        plan = self.get_plan(u, False, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        Nyquist_mode = min(k)
        return self.xp.where(k == Nyquist_mode)[0][0]

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
886class SpectralHelper:
887 """
888 This class has three functions:
889 - Easily assemble matrices containing multiple equations
890 - Direct product of 1D bases to solve problems in more dimensions
891 - Distribute the FFTs to facilitate concurrency.
893 Attributes:
894 comm (mpi4py.Intracomm): MPI communicator
895 debug (bool): Perform additional checks at extra computational cost
896 useGPU (bool): Whether to use GPUs
897 axes (list): List of 1D bases
898 components (list): List of strings of the names of components in the equations
899 full_BCs (list): List of Dictionaries containing all information about the boundary conditions
900 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
901 BCs (sparse matrix): Matrix containing only the BCs
902 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
903 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
904 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs
905 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
906 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
907 global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
908 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
909 init (tuple): This is the same `init` that is used throughout the problem classes
910 init_forward (tuple): This is the equivalent of `init` in spectral space
911 """
913 xp = np
914 fft_lib = scipy.fft
915 sparse_lib = scipy.sparse
916 linalg = scipy.sparse.linalg
917 dtype = mesh
918 fft_backend = 'scipy'
919 fft_comm_backend = 'MPI'
921 @classmethod
922 def setup_GPU(cls):
923 """switch to GPU modules"""
924 import cupy as cp
925 import cupyx.scipy.sparse as sparse_lib
926 import cupyx.scipy.sparse.linalg as linalg
927 import cupyx.scipy.fft as fft_lib
928 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
930 cls.xp = cp
931 cls.sparse_lib = sparse_lib
932 cls.linalg = linalg
934 cls.fft_lib = fft_lib
935 cls.fft_backend = 'cupyx-scipy'
936 cls.fft_comm_backend = 'NCCL'
938 cls.dtype = cupy_mesh
940 @classmethod
941 def setup_CPU(cls, useFFTW=False):
942 """switch to CPU modules"""
944 cls.xp = np
945 cls.sparse_lib = scipy.sparse
946 cls.linalg = scipy.sparse.linalg
948 if useFFTW:
949 from mpi4py_fft import fftw
951 cls.fft_backend = 'fftw'
952 cls.fft_lib = fftw
953 else:
954 cls.fft_backend = 'scipy'
955 cls.fft_lib = scipy.fft
957 cls.fft_comm_backend = 'MPI'
958 cls.dtype = mesh
960 def __init__(self, comm=None, useGPU=False, debug=False):
961 """
962 Constructor
964 Args:
965 comm (mpi4py.Intracomm): MPI communicator
966 useGPU (bool): Whether to use GPUs
967 debug (bool): Perform additional checks at extra computational cost
968 """
969 self.comm = comm
970 self.debug = debug
971 self.useGPU = useGPU
973 if useGPU:
974 self.setup_GPU()
975 else:
976 self.setup_CPU()
978 self.axes = []
979 self.components = []
981 self.full_BCs = []
982 self.BC_mat = None
983 self.BCs = None
985 self.fft_cache = {}
986 self.fft_dealias_shape_cache = {}
988 self.logger = logging.getLogger(name='Spectral Discretization')
989 if debug:
990 self.logger.setLevel(logging.DEBUG)
992 @property
993 def u_init(self):
994 """
995 Get empty data container in physical space
996 """
997 return self.dtype(self.init)
999 @property
1000 def u_init_forward(self):
1001 """
1002 Get empty data container in spectral space
1003 """
1004 return self.dtype(self.init_forward)
1006 @property
1007 def u_init_physical(self):
1008 """
1009 Get empty data container in physical space
1010 """
1011 return self.dtype(self.init_physical)
1013 @property
1014 def shape(self):
1015 """
1016 Get shape of individual solution component
1017 """
1018 return self.init[0][1:]
1020 @property
1021 def ndim(self):
1022 return len(self.axes)
1024 @property
1025 def ncomponents(self):
1026 return len(self.components)
1028 @property
1029 def V(self):
1030 """
1031 Get domain volume
1032 """
1033 return np.prod([me.L for me in self.axes])
def add_axis(self, base, *args, **kwargs):
    """
    Append a new axis discretized with the 1D spectral base selected by name.
    All further positional and keyword arguments are passed on to the constructor of the 1D base; refer to their
    documentation for what is admissible. The `useGPU` keyword is always overridden with the setting of this helper.

    Args:
        base (str): 1D spectral method, e.g. "chebychev", "fft" or "ultraspherical"

    Raises:
        NotImplementedError: If `base` is not a known name
    """
    kwargs['useGPU'] = self.useGPU

    name = base.lower()
    if name in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
        helper_class = ChebychevHelper
    elif name in ['fft', 'fourier', 'ffthelper']:
        helper_class = FFTHelper
    elif name in ['ultraspherical', 'gegenbauer']:
        helper_class = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    self.axes.append(helper_class(*args, **kwargs))

    # make the new base use the same array and sparse libraries as this helper
    new_axis = self.axes[-1]
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list/tuple of strings): Name(s) of component(s). Nested lists/tuples are flattened.

    Raises:
        Exception: If a component of the same name has already been added
        NotImplementedError: If `name` is neither a string nor a list/tuple of strings
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        # `isinstance` also admits str subclasses such as `numpy.str_`
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError(
            f'Cannot add component of type {type(name).__name__!r}, expected str or list/tuple of str'
        )
def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str or list of strings): Name(s) of component(s)

    Returns:
        int: Index of the component, or a tuple of indices if `name` is a list or tuple

    Raises:
        ValueError: If a name is not a registered component
        NotImplementedError: If `name` has an unsupported type
    """
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # a tuple, unlike a generator, can be indexed, measured and iterated more than once
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
def get_empty_operator_matrix(self, diag=False):
    """
    Build a container for the coupling operators between the solution components, filled with zero matrices
    of the size of the identity returned by `get_Id`.

    Args:
        diag (bool): Return a flat list (one entry per component) for block-diagonal operators instead of a
            list of lists

    Returns:
        list (of lists) of sparse zero matrices
    """
    zero = self.get_Id() * 0
    n_components = len(self.components)
    if not diag:
        return [[zero] * n_components for _ in range(n_components)]
    return [zero] * n_components
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): If True, the BC is embedded only at index 0 along the other axis/axes (using
            `diag(1, 0, ..., 0)`). Otherwise it is embedded at every index along the other axes (identity).

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than three dimensions
    """
    # the 1D BC matrix is always assembled on the host with scipy and converted with `self.sparse_lib` below
    sp = scipy.sparse

    base = self.axes[axis]

    # zero N x N matrix in which row `line` is replaced by the 1D BC row
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # copy the BC row from the device to the host
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        mat = self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        # matrix along the other axis: only index 0 for scalar BCs, all indices otherwise
        if scalar:
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        mats = [
            None,
        ] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
    elif ndim == 3:
        mats = [
            None,
        ] * ndim

        # same as in 2D, but for both other axes
        for ax in range(ndim):
            if ax == axis:
                continue

            if scalar:
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
            else:
                _Id = self.axes[ax].get_Id()

            mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

        # nested kron because sparse `kron` takes two matrices
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
    else:
        raise NotImplementedError(
            f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
        )
    mat = self.eliminate_zeros(mat)
    return mat
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.
    The entries in the rhs mask are cleared with the same indexing that `add_BC` uses to set them.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): If True, the BC was embedded only at index 0 along the other axes, otherwise at all indices
    """
    # `get_BC` already returns a matrix with explicit zeros eliminated
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # mirror `add_BC`: with MPI, scalar BCs are only flagged on rank 0
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = False
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        global_line = (N + line) % N
        if global_line in self.get_indices(True)[axis]:
            # convert the global row into an index of the local array, as done in `add_BC`
            slices[axis + 1] = global_line - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = False
def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Register a boundary condition. The BC row is added to the list-of-lists operator `self.BC_mat`, the BC is
    recorded in `self.full_BCs`, and the corresponding entry of `self.BC_rhs_mask` is set. Call `setup_BCs` after
    all BCs have been added/removed to assemble the single sparse BC operator.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): If True, embed the BC only at index 0 along the other axes, otherwise at all indices
    """
    new_BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += new_BC
    self.full_BCs.append(
        dict(
            component=component,
            equation=equation,
            axis=axis,
            kind=kind,
            v=v,
            line=line,
            scalar=scalar,
            **kwargs,
        )
    )

    if scalar:
        position = [self.index(equation)] + [0] * self.ndim
        position[axis + 1] = line
        # with MPI, only the first rank flags scalar BCs
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[tuple(position)] = True
        return

    position = [self.index(equation), *self.global_slice(True)]
    N = self.axes[axis].N
    global_line = (N + line) % N
    if global_line in self.get_indices(True)[axis]:
        # translate the global row into an index of the local array
        position[axis + 1] = global_line - self.local_slice()[axis].start
        self.BC_rhs_mask[tuple(position)] = True
def setup_BCs(self):
    """
    Assemble the boundary condition operator `self.BCs` from the list of lists `self.BC_mat`.

    Boundary bordering also requires zeroing all other entries in the rows that hold a BC. For that, a sparse
    diagonal matrix `self.BC_line_zero_matrix` is set up with zeros in exactly those rows. Finally, the BC
    values are transformed to spectral space once (`self.rhs_BCs_hat`) so they can be added to the rhs cheaply.
    """
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    flat_positions = self.xp.arange(np.prod(self.init[0]))
    self.BC_zero_index = flat_positions[self.BC_rhs_mask.flatten()]

    keep_row = self.xp.ones(self.BCs.shape[0])
    keep_row[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = self.sparse_lib.diags(keep_row)

    physical_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(physical_BCs)
def check_BCs(self, u):
    """
    Verify that `u` satisfies all non-scalar boundary conditions registered via `add_BC`.
    Only implemented for one- and two-dimensional problems.

    Args:
        u: The solution to check, in physical space with a leading component axis

    Raises:
        AssertionError: If the dimension is unsupported or a boundary condition is violated
    """
    assert self.ndim < 3
    # entries of the stored BC dicts that are not passed on to the 1D `get_BC`
    not_forwarded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        BCs_here = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not BCs_here:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for BC in BCs_here:
            kwargs = {key: BC[key] for key in BC if key not in not_forwarded}

            if axis == 0:
                get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
            elif axis == 1:
                get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'
def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions into the matrix `A`. Rows holding a BC are first zeroed with
    `self.BC_line_zero_matrix`, then the BC operator `self.BCs` is added.
    """
    A_without_BC_rows = self.BC_line_zero_matrix @ A
    return A_without_BC_rows + self.BCs
def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    No transforms are needed. The entries replaced by BCs are described by a boolean mask that is computed on
    the first call and cached in `self._rhs_hat_zero_mask`.
    Note that `rhs_hat` is modified in place: the masked entries are set to zero before the BCs are added.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Mark every entry of the rhs that is replaced by a boundary condition.
        self._rhs_hat_zero_mask = mask = self.newDistArray().astype(bool)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if bc['axis'] != axis:
                    continue
                position = [self.index(bc['equation']), *self.global_slice(True)]
                N = self.axes[axis].N
                global_line = (N + bc['line']) % N
                if global_line in self.get_indices(True)[axis]:
                    # translate the global row into an index of the local array
                    position[axis + 1] = global_line - self.local_slice()[axis].start
                    mask[tuple(position)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat
def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    For every axis, the rhs is transformed along that axis only, the values of all BCs along that axis are
    inserted, and the result is transformed back. Use `put_BCs_in_rhs_hat` to avoid these transforms.

    Args:
        rhs: Right hand side in physical space (not flattened)

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        transform_axes = (axis - ndim,)
        partially_transformed = self.transform(rhs, axes=transform_axes)

        for bc in (bc for bc in self.full_BCs if bc['axis'] == axis):
            position = [self.index(bc['equation']), *self.global_slice(True)]

            N = self.axes[axis].N
            global_line = (N + bc['line']) % N
            if global_line in self.get_indices(True)[axis]:
                # translate the global row into an index of the local array
                position[axis + 1] = global_line - self.local_slice()[axis].start
                partially_transformed[tuple(position)] = bc['v']

        rhs = self.itransform(partially_transformed, axes=transform_axes)

    return rhs
def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python

            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            # keyword arguments are forwarded to the 1D base; a positional argument would be taken as `axes`
            T2U = helper.get_basis_change_matrix(conv='T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for equation, relations in L_lhs.items():
                helper.add_equation_lhs(operator, equation, relations)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled, e.g. from `get_empty_operator_matrix`.
            It is modified in place.
        equation (str): The equation (row of `A`) you want to set entries in
        relations (dict): Maps component names (columns of `A`) to the matrix coupling that component into
            `equation`
    """
    for component, operator in relations.items():
        A[self.index(equation)][self.index(component)] = operator
def eliminate_zeros(self, A):
    """
    Remove explicitly stored zeros from a sparse matrix to reduce its memory footprint.
    On GPUs the matrix is copied to the host for this and the result is copied back, because (at the time of
    writing) the cupy implementation of `eliminate_zeros` has memory problems.
    Note that on the CPU, a matrix that is already in CSC format is pruned in place.

    Args:
        A: sparse matrix to be pruned

    Returns:
        CSC sparse matrix
    """
    on_gpu = self.useGPU
    if on_gpu:
        A = A.get()
    A = A.tocsc()
    A.eliminate_zeros()
    return self.sparse_lib.csc_matrix(A) if on_gpu else A
def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The block operator, one block per (equation, component) pair

    Returns:
        sparse linear operator (CSC, with explicit zeros removed)
    """
    single_component = len(self.components) == 1
    op = M[0][0] if single_component else self.sparse_lib.bmat(M, format='csc')
    return self.eliminate_zeros(op)
def get_wavenumbers(self):
    """
    Get the local part of the grid in spectral space, as arrays from `meshgrid` with 'ij' indexing.
    """
    local_wavenumbers = [
        base.get_wavenumbers()[self.local_slice(True)[i]] for i, base in enumerate(self.axes)
    ]
    return self.xp.meshgrid(*local_wavenumbers, indexing='ij')
def get_grid(self, forward_output=False):
    """
    Get the local part of the grid in physical space, as arrays from `meshgrid` with 'ij' indexing.

    Args:
        forward_output (bool): Use the local slices of the spectral (True) or physical (False) data layout
    """
    local_coordinates = [
        base.get_1dgrid()[self.local_slice(forward_output)[i]] for i, base in enumerate(self.axes)
    ]
    return self.xp.meshgrid(*local_coordinates, indexing='ij')
def get_indices(self, forward_output=True):
    """
    Get the global indices held on this rank along every axis.

    Args:
        forward_output (bool): Use the local slices of the spectral (True) or physical (False) data layout

    Returns:
        list of index arrays, one per axis
    """
    indices = []
    for i, base in enumerate(self.axes):
        local = self.local_slice(forward_output)[i]
        indices.append(self.xp.arange(base.N)[local])
    return indices
@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Get an `mpi4py_fft.PFFT` object for distributed transforms.
    Return values are cached per instance by the `cache` decorator, which uses the arguments as a dictionary
    key. All arguments therefore need to be hashable, e.g. pass `padding` as a tuple rather than a list.

    Args:
        axes (tuple): Axes that are transformed with the transforms of the 1D bases. All other axes get an
            identity "transform". Defaults to all axes.
        padding (tuple): Padding factor per axis, forwarded to `PFFT`. Defaults to 1 in every axis.
        grid: Forwarded to `PFFT` as `grid`

    Returns:
        PFFT object, or None for 1D problems or when running without MPI
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    # identity "transform" for axes that are not transformed
    def no_transform(u, *args, **kwargs):
        return u

    # every axis gets an entry; axes in `axes` (normalized to non-negative keys) use the 1D bases' transforms
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )
    return pfft
def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get transform functions or objects. When using MPI, we use `PFFT` objects generated by mpi4py-fft.
    Results are cached in `self.fft_cache`, keyed on all arguments.

    Args:
        axes (tuple): Axes you want to transform over. Defaults to all axes, given as negative indices.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing. Defaults to 1 in all axes.
        shape (tuple): Shape of the transform. Defaults to the global shape without the component axis.

    Returns:
        transform. Without MPI, this is `self.xp.fft.fftn` / `self.xp.fft.ifftn` for "forward" / "backward"
        and None for "object". An unknown `direction` stores nothing in the cache, so the final lookup raises
        a KeyError.

    Raises:
        AssertionError: If padding is requested without MPI
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache.keys():
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            # NOTE(review): the returned functions are not bound to `axes`; callers need to pass axes themselves
            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # forward/backward functions are taken from the (cached) PFFT object for the same arguments
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    return self.fft_cache[key]
def local_slice(self, forward_output=True):
    """
    Get the slices of the global data, one per axis, that are held on this rank.
    Without a distributed FFT object, the full extent of every axis is returned.

    Args:
        forward_output (bool): Slices of the spectral (True) or physical (False) data layout
    """
    if not self.fft_obj:
        return [slice(0, base.N) for base in self.axes]
    return self.get_pfft().local_slice(forward_output=forward_output)
def global_slice(self, forward_output=True):
    """
    Get slices covering the full global extent of every axis.
    Without a distributed FFT object, these coincide with `local_slice`.

    Args:
        forward_output (bool): Slices of the spectral (True) or physical (False) data layout
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    global_shape = self.fft_obj.global_shape(forward_output=forward_output)
    return [slice(0, n) for n in global_shape]
def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.
    If no component has been added yet, a single component 'u' is added.

    Sets `global_shape`, `fft_obj`, `init`, `init_physical`, `init_forward`, `BC_mat` and `BC_rhs_mask`.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_pfft(axes=axes)

    def local_shape(forward_output):
        # Shape of the data on this rank, computed directly from the local slices instead of slicing a
        # temporary array of global size (which would be allocated on every rank).
        slices = self.local_slice(forward_output)
        return (self.global_shape[0],) + tuple(len(range(n)[s]) for n, s in zip(self.global_shape[1:], slices))

    physical_shape = local_shape(False)
    self.init = (physical_shape, self.comm, np.dtype('float'))
    self.init_physical = (physical_shape, self.comm, np.dtype('float'))
    self.init_forward = (
        local_shape(True),
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)
def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Args:
        pfft: PFFT object defining the distribution. Defaults to `self.get_pfft()`.
        forward_output (bool): Layout of spectral (True) or physical (False) space
        val: Forwarded to the DistArray constructor as `val` (presumably the fill value; confirm in mpi4py-fft)
        rank (int): Number of leading component axes, each of size `self.ncomponents`
        view (bool): Return `z.v` instead of the DistArray `z` itself

    Returns:
        Array with leading component axis. Without MPI, this is a plain zero array of shape `self.init[0]` and
        dtype `self.init[2]`, and all other arguments are ignored. If no PFFT object is available (`get_pfft`
        returns None for 1D problems), `u_init_forward` or `u_init` is returned.
    """
    if self.comm is None:
        return self.xp.zeros(self.init[0], dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft if pfft else self.get_pfft()
    if pfft is None:
        if forward_output:
            return self.u_init_forward
        else:
            return self.u_init

    global_shape = pfft.global_shape(forward_output)
    p0 = pfft.pencil[forward_output]
    if forward_output is True:
        dtype = pfft.forward.output_array.dtype
    else:
        dtype = pfft.forward.input_array.dtype
    # prepend `rank` component axes to the spatial shape
    global_shape = (self.ncomponents,) * rank + global_shape

    # use the CuPy flavour of DistArray when the PFFT runs with a CuPy backend
    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
    return z.v if view else z
def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Find the axes in which the local array `u` is held completely on this rank.
    An axis counts as aligned if the local size of `u` in that axis equals the global size of a distributed
    array with the requested layout. If `padding` is given, combinations of padded and unpadded axes are tried
    in turn until one yields at least one aligned axis.

    Args:
        u: Local array including the leading component axis
        forward_output (bool): Compare against the spectral (True) or physical (False) layout
        padding (tuple): Padding factors per axis, if `u` may be padded
        **kwargs: Forwarded to `get_pfft`

    Returns:
        list of int: Aligned axes (not counting the component axis). `[0]` when running without MPI.

    Raises:
        NotImplementedError: If `padding` is given in other than 2 or 3 dimensions
        AssertionError: If no aligned axis is found
    """
    if self.comm is None:
        return [0]

    # axes in which the local size of `u` matches the global size of a reference array from `pfft`
    def _alignment(pfft):
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        # try partially padded layouts first, then fully padded, then unpadded
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes
def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Copy the local array `u` into a distributed array that is aligned in `axis`.

    Args:
        u: Local array including the leading component axis
        axis (int): Axis in which the returned array should be aligned
        forward_output (bool): Layout (spectral if True, physical if False) of the returned array
        **kwargs: Forwarded to `get_pfft` and `infer_alignment`, e.g. `padding`

    Returns:
        Distributed array aligned in `axis`, or `u` itself when running without MPI

    Raises:
        Exception: If no alignment of a distributed array matches the local shape of `u`
    """
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # NOTE(review): the alignment of `u` is always inferred with forward_output=False, independent of the
    # `forward_output` argument. Confirm this is intended for spectral-space input.
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )
def transform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from physical to spectral space.
    Without a PFFT object (`get_pfft` returns None for 1D problems or without MPI), the transforms of the 1D bases
    are applied one after another along `axes`. Otherwise, the PFFT forward transform is applied to every
    component separately.

    Args:
        u: Data in physical space including the leading component axis
        axes (tuple): Axes to transform along. Defaults to all axes.
        padding (tuple): Padding factors for dealiasing. Only used when a PFFT object is available.
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in spectral space
    """
    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if pfft is None:
        # NOTE(review): `padding` is silently ignored on this code path; confirm this is intended
        axes = axes if axes else tuple(i for i in range(self.ndim))
        u_hat = u.copy()
        for i in axes:
            # shift non-negative axes by one to skip the leading component axis
            _axis = 1 + i if i >= 0 else i
            u_hat = self.axes[i].transform(u_hat, axes=(_axis,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=False, rank=1)
    _out = self.newDistArray(pfft, forward_output=True, rank=1)

    if _in.shape == u.shape:
        _in[...] = u
    else:
        # local shape of `u` differs from the PFFT input, so redistribute it first
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs)

    for i in range(self.ncomponents):
        pfft.forward(_in[i], _out[i], normalize=False)

    if padding is not None:
        # divide by the total padding factor (the forward transform above is unnormalized)
        _out /= np.prod(padding)
    return _out
def itransform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from spectral to physical space.
    Without a PFFT object (`get_pfft` returns None for 1D problems or without MPI), the inverse transforms of the
    1D bases are applied one after another along `axes`. Otherwise, the PFFT backward transform is applied to
    every component separately.

    Args:
        u: Data in spectral space including the leading component axis
        axes (tuple): Axes to transform along. Defaults to all axes.
        padding (tuple): Padding factors for dealiasing. `N * padding` must be an integer in every axis.
            Only used for the transform when a PFFT object is available.
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in physical space

    Raises:
        AssertionError: If the padded resolution is not an integer
    """
    if padding is not None:
        assert all(
            (self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim)
        ), 'Cannot do this padding with this resolution. Resulting resolution must be integer'

    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
    if pfft is None:
        # NOTE(review): `padding` is not applied on this code path; confirm this is intended
        axes = axes if axes else tuple(i for i in range(self.ndim))
        u_hat = u.copy()
        for i in axes:
            # shift non-negative axes by one to skip the leading component axis
            _axis = 1 + i if i >= 0 else i
            u_hat = self.axes[i].itransform(u_hat, axes=(_axis,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=True, rank=1)
    _out = self.newDistArray(pfft, forward_output=False, rank=1)

    if _in.shape == u.shape:
        _in[...] = u
    else:
        # local shape of `u` differs from the PFFT input, so redistribute it first
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs)

    for i in range(self.ncomponents):
        pfft.backward(_in[i], _out[i], normalize=True)

    if padding is not None:
        # multiply by the total padding factor, undoing the division in `transform`
        _out *= np.prod(padding)
    return _out
def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Restrict a global 1D matrix to the modes held on this rank.
    With distributed FFTs, every rank only carries a subset of the modes, given by `self.local_slice()`. The
    rows and columns of `M` corresponding to that subset along `axis` are selected.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Axis whose local slice is used

    Returns:
        sparse local matrix in CSC format
    """
    M_csc = M.tocsc()
    local_modes = self.local_slice(True)[axis]
    return M_csc[local_modes, local_modes]
def expand_matrix_ND(self, matrix, aligned):
    """
    Embed a 1D matrix acting along axis `aligned` into the full (local) N-dimensional problem.
    All other axes get (locally sliced) identity matrices, and the factors are combined with Kronecker products
    with axis 0 as the outermost factor. In 1D, `matrix` is used as is.

    Args:
        matrix (sparse matrix): Global 1D matrix acting along `aligned`
        aligned (int): Axis the matrix acts along

    Returns:
        CSC sparse matrix

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = self.sparse_lib
    other_axes = np.delete(np.arange(self.ndim), aligned)
    ndim = len(other_axes) + 1

    if ndim > 3:
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    if ndim == 1:
        expanded = matrix
    else:
        factors = [None] * ndim
        factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
        for other in other_axes:
            factors[other] = self.get_local_slice_of_1D_matrix(sp.eye(self.axes[other].N), other)

        # fold from the right: kron(f0, f1) in 2D and kron(f0, kron(f1, f2)) in 3D
        expanded = factors[-1]
        for factor in reversed(factors[:-1]):
            expanded = sp.kron(factor, expanded)

    return self.eliminate_zeros(expanded)
def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible.
    In more than one dimension, the 1D filter is embedded with `expand_matrix_ND` like all other operators. This
    uses the locally held modes with distributed FFTs. The other axes get the identity of their 1D base.

    Args:
        axis (int): Axis along which to filter (negative values count from the end)

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    base = self.axes[axis]  # raises IndexError for invalid axes
    axis = axis % self.ndim

    # `sparse_lib.kron` only accepts two matrices (the third positional argument is `format`), so the Kronecker
    # product is built as a product of 1D matrices expanded along their own axes; these factors commute.
    F = self.expand_matrix_ND(base.get_filter_matrix(**kwargs), axis)
    for other in range(self.ndim):
        if other != axis:
            F = F @ self.expand_matrix_ND(self.axes[other].get_Id(), other)
    return F
def get_differentiation_matrix(self, axes, **kwargs):
    """
    Build the product of the 1D differentiation matrices along all `axes`, each expanded to the N-dimensional
    problem. `kwargs` are passed on to every 1D base.

    Args:
        axes (tuple): Axes along which to differentiate (at least one)

    Returns:
        sparse differentiation matrix
    """
    first, remaining = axes[0], axes[1:]
    D = self.expand_matrix_ND(self.axes[first].get_differentiation_matrix(**kwargs), first)
    for ax in remaining:
        D = D @ self.expand_matrix_ND(self.axes[ax].get_differentiation_matrix(**kwargs), ax)
    return D
def get_integration_matrix(self, axes, **kwargs):
    """
    Get integration matrix to integrate along specified axes.
    `kwargs` are forwarded to the 1D base implementations, as in `get_differentiation_matrix` and
    `get_basis_change_matrix`. Without kwargs, the behaviour is unchanged.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """
    S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(**kwargs), axes[0])
    for axis in axes[1:]:
        _S = self.axes[axis].get_integration_matrix(**kwargs)
        S = S @ self.expand_matrix_ND(_S, axis)

    return S
def get_Id(self):
    """
    Build the identity of the N-dimensional problem as the product of the 1D identities of all bases, each
    expanded along its own axis.

    Returns:
        sparse identity matrix
    """
    I = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
    for ax, base in enumerate(self.axes[1:], start=1):
        I = I @ self.expand_matrix_ND(base.get_Id(), ax)
    return I
def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get the Dirichlet recombination matrix of the 1D base along `axis`, expanded to the N-dimensional problem.
    Note that this only makes sense in directions discretized with variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    recombination_1D = self.axes[axis].get_Dirichlet_recombination_matrix()
    expanded = self.expand_matrix_ND(recombination_1D, axis)
    return expanded
def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases change the basis while differentiating. This builds the product of the 1D basis change
    matrices along `axes`, each expanded to the N-dimensional problem.
    `kwargs` are passed on to the 1D bases; refer to their methods of the same name for admissible parameters.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes as negative indices, starting from
            the last one.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))

    first, remaining = axes[0], axes[1:]
    C = self.expand_matrix_ND(self.axes[first].get_basis_change_matrix(**kwargs), first)
    for ax in remaining:
        C = C @ self.expand_matrix_ND(self.axes[ax].get_basis_change_matrix(**kwargs), ax)
    return C