Coverage for pySDC/helpers/spectral_helper.py: 90%
765 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 11:36 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 11:36 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
5from functools import partial, wraps
6import logging
def cache(func):
    """
    Decorator for caching return values of methods.

    This is very similar to `functools.cache`, but the cache is stored on the instance rather than on the function,
    so it does not keep instances alive and is released together with the instance (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/).

    The decorated function must be a method: its first argument is the object holding the cache, which is stored in
    an attribute called `_<function name>_cache`. All remaining positional and keyword arguments form the cache key
    and must therefore be hashable. Cached values are returned by reference, so do not modify them in place.

    Example:

    .. code-block:: python

        class Counter:
            num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        c = Counter()
        c.increment(0)  # returns 1, c.num_calls = 1
        c.increment(1)  # returns 2, c.num_calls = 2
        c.increment(0)  # returns 1, c.num_calls = 2

    Args:
        func (function): The method you want to cache the return value of

    Returns:
        function: Wrapped method that returns cached values where available
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # the per-instance cache is created lazily on the first call
        if not hasattr(self, attr_cache):
            setattr(self, attr_cache, {})
        cache = getattr(self, attr_cache)

        # keyword arguments are order-independent, hence the frozenset
        key = (args, frozenset(kwargs.items()))
        if key not in cache:
            cache[key] = func(self, *args, **kwargs)
        return cache[key]

    return wrapper
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs. Note that `setup_GPU` and `setup_CPU` change these modules on the class, i.e. for all instances.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both GPUs and FFTW are requested
        """
        # Validate the options before switching modules. Otherwise, the import of the GPU libraries would be attempted
        # first and the user would not see this error.
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

        # keep backend information and data type consistent with the modules, like `setup_CPU` does
        cls.fft_backend = 'cupyx-scipy'
        cls.fft_comm_backend = 'NCCL'
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """
        switch to CPU modules

        Args:
            useFFTW (bool): Use `mpi4py_fft.fftw` instead of `scipy.fft` for the transforms
        """
        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
                individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Modes with absolute wavenumber `kmin <= |k| < kmax` are kept, all others are set to zero.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter. Defaults to the largest absolute wavenumber, which is then
                removed by the filter as well

        Returns:
            sparse matrix
        """
        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        # The column assignment needs scipy's LIL format, which has no GPU equivalent. Hence, both the identity and the
        # mask are moved to the CPU, and the result is moved back to the GPU afterwards.
        Id = self.get_Id()
        if self.useGPU:
            Id = Id.get()
            mask = mask.get()
        F = Id.tolil()
        F[:, mask] = 0
        F = F.tocsc()

        return self.sparse_lib.csc_matrix(F) if self.useGPU else F

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary. Note that some methods, such as `get_Dirichlet_BC_row`, expect
                coordinates on the reference interval [-1, 1] regardless.
            x1 (float): Coordinate of right boundary. See note for `x0`.
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like, but do not modify
        the returned matrices in place.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the requested conversion nor its reverse is implemented
        '''
        N = N if N else self.N
        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                # conversion to the same basis, respecting the requested size
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # fall back to inverting the conversion in the opposite direction
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}') from E

            if self.sparse_lib == scipy.sparse:
                mat = scipy.sparse.linalg.inv(fwd.tocsc())
            else:
                # sparse inversion is not implemented in CuPy at the time of writing, so we invert on the CPU
                mat = self.sparse_lib.csc_matrix(scipy.sparse.linalg.inv(fwd.tocsc().get()))

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: If `lbnd` is not 0
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        S = (self.sparse_lib.diags(1 / (xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')).tocsc()

        # The first row sets the integration constant corresponding to the lower bound x=0. It has entries only in the
        # odd columns. The denominator is [1, 1, 2, 3, ...], computed on the device of the other arrays.
        n = xp.arange(self.N)
        k = xp.arange(self.N // 2)
        S[0, 1::2] = ((n / (2 * (n + 1)))[1::2] * (-1) ** k / xp.maximum(k, 1)) * self.lin_trf_fac
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        # Entries of the first derivative: D[k, j] = 2j for k < j where j - k is odd, and zero otherwise. The first
        # row is halved below. Assembled vectorized instead of element by element.
        j = self.xp.arange(self.N)
        D = self.xp.triu(2 * j[None, :] * ((j[None, :] - j[:, None]) % 2), k=1).astype(float)

        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Passed on to the FFT library as `s`

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Passed on to the FFT library as `s`

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # Work on a copy and accumulate the reordering and normalization of all axes in it. Starting over from `u` for
        # every axis would discard what was done for the previous axes.
        _u = u.copy()
        for axis in axes:

            if self.N != _u.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                N = self.N
                src = _u
                _u = self.xp.zeros_like(src)

                # copy first half
                su = [slice(None)] * src.ndim
                su[axis] = slice(0, N // 2 + 1)
                _u[tuple(su)] = src[tuple(su)]

                # copy second half
                su = [slice(None)] * src.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * src.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                _u[tuple(s_u)] = src[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * src.ndim
                    su[axis] = N // 2
                    _u[tuple(su)] *= 2

            # undo the normalization along this axis
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N
            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the interval [-1, 1].

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition, only -1, 0 and 1 are implemented

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_n(0) is 0 for odd n and alternates between 1 and -1 for even n
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from
    0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        # the p-th derivative maps into the p-th derivative base, where it occupies the p-th upper off-diagonal
        return 2 ** (p - 1) * factorial(p - 1) * sp.diags(xp.arange(N - p) + p, offsets=p) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # assembled with scipy because the column assignment needs the LIL format
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_in (int): Ingoing derivative base
            p_out (int): Resulting derivative base

        Returns:
            sparse conversion matrix
        """
        if p_in == p_out:
            # nothing to convert, skip the (expensive) sparse inversion of the identity
            return self.sparse_lib.csc_matrix(self.sparse_lib.eye(self.N))

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min(p_in, p_out), max(p_in, p_out)):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = scipy.sparse.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix`
        for details.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        n = u_hat.shape[axis]

        # select all but the zeroth mode along `axis` while keeping all other dimensions complete
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # alternating signs (-1)^k, broadcast along `axis` only
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = n - 1
        signs = ((-1) ** self.xp.arange(n - 1)).reshape(sign_shape)

        return self.xp.sum(u_hat[tuple(slices)] * signs, axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems. Differentiation and integration matrices are diagonal.
    """

    distributable = True

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary and excluding the right one, which coincides with the
        left boundary due to periodicity.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()

        # The p-th power of a diagonal matrix is the diagonal matrix of the p-th powers of its entries. This avoids
        # sparse matrix powers, which CuPy does not implement, and yields the identity for p=0 on all devices.
        return self.sparse_lib.diags((1j * k) ** p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        # the zeroth wavenumber vanishes, so it is replaced to avoid dividing by zero
        k[0] = 1j * self.L
        # element-wise power of the diagonal, see `get_differentiation_matrix`
        return self.sparse_lib.diags((1 / (1j * k)) ** p)

    def get_plan(self, u, forward, *args, **kwargs):
        """
        Get a callable that performs the FFT. With FFTW, plans are generated once per configuration and cached.

        Args:
            u: Data to be transformed
            forward (bool): Whether to get a forward or backward transform

        Returns:
            Callable performing the transform
        """
        if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
            if 'axes' in kwargs.keys():
                kwargs['axes'] = tuple(kwargs['axes'])
            key = (forward, u.shape, args, *(me for me in kwargs.values()))
            if key in self.plans.keys():
                return self.plans[key]
            else:
                self.logger.debug(f'Generating FFT plan for {key=}')
                transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs)
                self.plans[key] = transform

            return self.plans[key]
        else:
            if forward:
                return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward'))
            else:
                return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward'))

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        FFT along axes. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform over
            shape (tuple): Passed on to the FFT library as `s`

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally so that additional positional arguments do not collide with it
        plan = self.get_plan(u, True, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs)

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse FFT.

        Args:
            u: Data you want to transform
            axes (tuple): Axes over which to transform
            shape (tuple): Passed on to the FFT library as `s`

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally so that additional positional arguments do not collide with it
        plan = self.get_plan(u, False, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        # argmin returns the first occurrence of the minimum and stays on the device of the wavenumbers
        return int(self.xp.argmin(self.get_wavenumbers()))

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
class SpectralHelper:
    """
    This class serves three purposes:
     - Easily assemble matrices containing multiple equations
     - Build the direct product of 1D bases to solve problems in more dimensions
     - Distribute the FFTs to facilitate concurrency.

    Attributes:
        comm (mpi4py.Intracomm): MPI communicator
        debug (bool): Perform additional checks at extra computational cost
        useGPU (bool): Whether to use GPUs
        axes (list): List of 1D bases
        components (list): List of strings of the names of components in the equations
        full_BCs (list): List of dictionaries containing all information about the boundary conditions
        BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
        BCs (sparse matrix): Matrix containing only the BCs
        fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
        BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
        BC_zero_index (self.xp.ndarray): Indices of rows in the matrix that are replaced by BCs
        BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
        rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
        global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
        fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
        init (tuple): This is the same `init` that is used throughout the problem classes
        init_forward (tuple): This is the equivalent of `init` in spectral space
    """

    # CPU modules by default; `setup_GPU` / `setup_CPU` replace them on the class
    xp = np
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    fft_lib = scipy.fft

    # names of the backends used for the transforms and their communication
    fft_backend = 'scipy'
    fft_comm_backend = 'MPI'

    # data type of the solution containers
    dtype = mesh
923 @classmethod
924 def setup_GPU(cls):
925 """switch to GPU modules"""
926 import cupy as cp
927 import cupyx.scipy.sparse as sparse_lib
928 import cupyx.scipy.sparse.linalg as linalg
929 import cupyx.scipy.fft as fft_lib
930 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
932 cls.xp = cp
933 cls.sparse_lib = sparse_lib
934 cls.linalg = linalg
936 cls.fft_lib = fft_lib
937 cls.fft_backend = 'cupyx-scipy'
938 cls.fft_comm_backend = 'NCCL'
940 cls.dtype = cupy_mesh
942 @classmethod
943 def setup_CPU(cls, useFFTW=False):
944 """switch to CPU modules"""
946 cls.xp = np
947 cls.sparse_lib = scipy.sparse
948 cls.linalg = scipy.sparse.linalg
950 if useFFTW:
951 from mpi4py_fft import fftw
953 cls.fft_backend = 'fftw'
954 cls.fft_lib = fftw
955 else:
956 cls.fft_backend = 'scipy'
957 cls.fft_lib = scipy.fft
959 cls.fft_comm_backend = 'MPI'
960 cls.dtype = mesh
962 def __init__(self, comm=None, useGPU=False, debug=False):
963 """
964 Constructor
966 Args:
967 comm (mpi4py.Intracomm): MPI communicator
968 useGPU (bool): Whether to use GPUs
969 debug (bool): Perform additional checks at extra computational cost
970 """
971 self.comm = comm
972 self.debug = debug
973 self.useGPU = useGPU
975 if useGPU:
976 self.setup_GPU()
977 else:
978 self.setup_CPU()
980 self.axes = []
981 self.components = []
983 self.full_BCs = []
984 self.BC_mat = None
985 self.BCs = None
987 self.fft_cache = {}
988 self.fft_dealias_shape_cache = {}
990 self.logger = logging.getLogger(name='Spectral Discretization')
991 if debug:
992 self.logger.setLevel(logging.DEBUG)
994 @property
995 def u_init(self):
996 """
997 Get empty data container in physical space
998 """
999 return self.dtype(self.init)
1001 @property
1002 def u_init_forward(self):
1003 """
1004 Get empty data container in spectral space
1005 """
1006 return self.dtype(self.init_forward)
1008 @property
1009 def u_init_physical(self):
1010 """
1011 Get empty data container in physical space
1012 """
1013 return self.dtype(self.init_physical)
1015 @property
1016 def shape(self):
1017 """
1018 Get shape of individual solution component
1019 """
1020 return self.init[0][1:]
1022 @property
1023 def ndim(self):
1024 return len(self.axes)
1026 @property
1027 def ncomponents(self):
1028 return len(self.components)
1030 @property
1031 def V(self):
1032 """
1033 Get domain volume
1034 """
1035 return np.prod([me.L for me in self.axes])
def add_axis(self, base, *args, **kwargs):
    """
    Add an axis to the domain by deciding on suitable 1D base.
    Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
    bases for possible arguments.

    Accepted (case-insensitive) names:
        - Chebychev: 'chebychov', 'chebychev', 'cheby', 'chebychovhelper', 'chebychevhelper'
        - Fourier: 'fft', 'fourier', 'ffthelper'
        - Ultraspherical: 'ultraspherical', 'gegenbauer', 'ultrasphericalhelper'

    Args:
        base (str): 1D spectral method

    Raises:
        NotImplementedError: If `base` is not one of the accepted names
    """
    # The GPU setting of the whole discretization always overrides whatever was passed for the 1D base.
    kwargs['useGPU'] = self.useGPU

    name = base.lower()
    if name in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper', 'chebychevhelper']:
        helper_class = ChebychevHelper
    elif name in ['fft', 'fourier', 'ffthelper']:
        helper_class = FFTHelper
    elif name in ['ultraspherical', 'gegenbauer', 'ultrasphericalhelper']:
        helper_class = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')
    self.axes.append(helper_class(*args, **kwargs))

    # Share the array / sparse libraries of the N-dimensional helper with the new 1D base.
    self.axes[-1].xp = self.xp
    self.axes[-1].sparse_lib = self.sparse_lib
def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list/tuple of str): Name(s) of component(s). Lists and tuples may be nested.

    Raises:
        Exception: If a component of that name has already been added
        NotImplementedError: If `name` is neither a string nor a list/tuple of strings
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError(f'Don\'t know how to add a component of {type(name)=}')
def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str or list/tuple of str): Name(s) of component(s)

    Returns:
        int: Index of the component, or a tuple of indices if a list/tuple of names was given

    Raises:
        ValueError: If a name is not a registered component
        NotImplementedError: If `name` has an unsupported type
    """
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # A tuple (rather than a one-shot generator) can be reused, indexed and unpacked.
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
def get_empty_operator_matrix(self, diag=False):
    """
    Build a container of zero operators to be filled with the couplings between the solution components.

    All entries refer to the same zero matrix object (`self.get_Id() * 0`).

    Args:
        diag (bool): Return a flat list (block-diagonal operator) instead of a list of lists

    Returns:
        list (or list of lists) of sparse zero matrices, one entry per component (pair)
    """
    n_comp = len(self.components)
    zero = self.get_Id() * 0
    if not diag:
        return [[zero] * n_comp for _ in range(n_comp)]
    return [zero] * n_comp
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # All-zero N x N matrix built with scipy (LIL format) whose row `line` is replaced by the 1D BC row.
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # NOTE(review): presumably the GPU base returns a CuPy array; `.get()` copies it to the host for scipy.
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # Only index 0 in the other direction carries the BC (diagonal with a single 1 in the first entry).
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        # Restrict the factor in the other direction to the modes held by this rank.
        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        mats = [
            None,
        ] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        return self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
    if ndim == 3:
        mats = [
            None,
        ] * ndim

        # Identity (or first-mode-only projection for scalar BCs) in all directions other than `axis`.
        for ax in range(ndim):
            if ax == axis:
                continue

            if scalar:
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
            else:
                _Id = self.axes[ax].get_Id()

            mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

        # `kron` only takes two matrices, so the three factors are nested.
        return self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
    else:
        raise NotImplementedError(
            f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
        )
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    The right hand side mask is cleared with exactly the same index logic `add_BC` uses to set it, so that removing
    a BC undoes adding it, also when the data is distributed.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # `add_BC` only marks scalar BCs on rank 0 (or without communicator); only undo it there.
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = False
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        global_line = (N + line) % N
        if global_line in self.get_indices(True)[axis]:
            # The mask is rank-local: convert the global line to the local index on this rank.
            slices[axis + 1] = global_line - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = False
def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Besides adding the BC row to `self.BC_mat`, this records the full specification in `self.full_BCs` and marks
    the affected entry/entries of the right hand side in `self.BC_rhs_mask`.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    # Keep the complete specification; it is read back by `put_BCs_in_rhs`, `put_BCs_in_rhs_hat` and `check_BCs`.
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        # Scalar BC: a single entry, at index 0 in every other direction and `line` along `axis`.
        slices = [self.index(equation)] + [
            0,
        ] * self.ndim
        slices[axis + 1] = line
        if self.comm:
            # NOTE(review): only rank 0 marks the entry, presumably because it holds index 0 in the other
            # directions; `line` is used as a local index here without conversion -- confirm for distributed axes.
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        # Only mark the row if this rank holds it; convert the (possibly negative) global line to a local index.
        if (N + line) % N in self.get_indices(True)[axis]:
            slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = True
def setup_BCs(self):
    """
    Convert the list of lists of BCs to the boundary condition operator.
    Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
    condition. This method sets up a suitable sparse matrix to do this.

    Sets `self.BCs`, `self.BC_zero_index`, `self.BC_line_zero_matrix` and `self.rhs_BCs_hat`.
    """
    sp = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)
    # Flat indices of all rows carrying a BC. The index range is taken from the mask itself, which has the local
    # spectral layout of the operator, rather than from the physical local shape `self.init[0]`; the two can
    # differ in size when the data is distributed unevenly.
    self.BC_zero_index = self.xp.arange(self.BC_rhs_mask.size)[self.BC_rhs_mask.flatten()]

    diags = self.xp.ones(self.BCs.shape[0])
    diags[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sp.diags(diags)

    # prepare BCs in spectral space to easily add to the RHS
    rhs_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(rhs_BCs)
def check_BCs(self, u):
    """
    Assert that the solution satisfies all non-scalar boundary conditions (only implemented for 1D and 2D).

    For each axis carrying BCs, `u` is transformed along that axis and the BC row of the 1D base is applied to the
    component the BC acts on; the result must be close to the prescribed value `v`.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a BC is violated or the problem has three or more dimensions
    """
    assert self.ndim < 3
    not_forwarded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        relevant = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not relevant:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for bc in relevant:
            bc_kwargs = {key: val for key, val in bc.items() if key not in not_forwarded}
            stencil = self.axes[axis].get_BC(**bc_kwargs)
            component_hat = u_hat[self.index(bc['component'])]
            # The BC row acts on the first (axis 0) or second (axis 1) index of the component.
            got = stencil @ component_hat if axis == 0 else component_hat @ stencil
            expected = bc['v']
            assert self.xp.allclose(
                got, expected
            ), f'Unexpected BC in {bc["component"]} in equation {bc["equation"]}, line {bc["line"]}! Got {got}, wanted {expected}'
def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions in a matrix by replacing rows with BCs.

    Rows carrying a BC are zeroed via `self.BC_line_zero_matrix`, then the BC operator `self.BCs` is added.

    Args:
        A: Operator to put the BCs in

    Returns:
        Operator with BC rows replaced
    """
    masked = self.BC_line_zero_matrix @ A
    return masked + self.BCs
def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.

    Note that the entries of `rhs_hat` in BC rows are zeroed in place before `self.rhs_BCs_hat` is added.
    The mask is built once on the first call and reused afterwards.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Build (once) the mask of spectral coefficients that get replaced by boundary conditions.
        mask = self.newDistArray().astype(bool)
        self._rhs_hat_zero_mask = mask

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if bc['axis'] != axis:
                    continue
                target = [self.index(bc['equation']), *self.global_slice(True)]
                N = self.axes[axis].N
                global_line = (N + bc['line']) % N
                # Only rows held by this rank are marked, using their local index.
                if global_line in self.get_indices(True)[axis]:
                    target[axis + 1] = global_line - self.local_slice()[axis].start
                    mask[tuple(target)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat
def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space (not flattened)

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    n_dim = self.ndim
    for axis in range(n_dim):
        # Go to spectral space along this axis only, overwrite the BC rows, and come back.
        rhs_hat_axis = self.transform(rhs, axes=(axis - n_dim,))

        for bc in self.full_BCs:
            if bc['axis'] != axis:
                continue
            target = [self.index(bc['equation']), *self.global_slice(True)]
            N = self.axes[axis].N
            global_line = (N + bc['line']) % N
            if global_line in self.get_indices(True)[axis]:
                target[axis + 1] = global_line - self.local_slice()[axis].start
                rhs_hat_axis[tuple(target)] = bc['v']

        rhs = self.itransform(rhs_hat_axis, axes=(axis - n_dim,))

    return rhs
def add_equation_lhs(self, A, equation, relations, diag=False):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled in place
        equation (str): The equation of the component you want this in
        relations: (dict): Maps component names to the operator acting on that component in this equation
        diag (bool): Whether operator is block-diagonal
    """
    for key, op in relations.items():
        if not diag:
            A[self.index(equation)][self.index(key)] = op
            continue
        assert key == equation, 'You are trying to put a non-diagonal equation into a diagonal operator'
        A[self.index(equation)] = op
def convert_operator_matrix_to_operator(self, M, diag=False):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator to be converted (a flat list if `diag`)
        diag (bool): Whether `M` is a flat list describing a block-diagonal operator

    Returns:
        sparse linear operator (CSC format when there is more than one component)
    """
    if len(self.components) == 1:
        # Nothing to assemble for a single component.
        return M[0] if diag else M[0][0]
    if diag:
        return self.sparse_lib.block_diag(M, format='csc')
    return self.sparse_lib.block_array(M, format='csc')
def get_wavenumbers(self):
    """
    Get grid in spectral space: the wavenumbers of the modes held by this rank, as an 'ij'-indexed meshgrid.
    """
    grids = [base.get_wavenumbers()[self.local_slice(True)[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')
def get_grid(self, forward_output=False):
    """
    Get grid in physical space: the points held by this rank, as an 'ij'-indexed meshgrid.

    Args:
        forward_output (bool): Use the local slices of the spectral (True) or physical (False) layout
    """
    grids = [base.get_1dgrid()[self.local_slice(forward_output)[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')
def get_indices(self, forward_output=True):
    """
    Global indices of the entries held by this rank, one array per axis.

    Args:
        forward_output (bool): Use the local slices of the spectral (True) or physical (False) layout
    """
    indices = []
    for i, base in enumerate(self.axes):
        indices.append(self.xp.arange(base.N)[self.local_slice(forward_output)[i]])
    return indices
@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Get an mpi4py-fft `PFFT` object for distributed transforms along `axes`.

    Results are memoized per instance by `@cache`, which uses the arguments as a dictionary key, so all arguments
    must be hashable (e.g. tuples rather than lists).

    Args:
        axes (tuple): Axes to transform along. Defaults to all axes.
        padding (sequence of float): Padding factor per axis for dealiasing. Defaults to 1 in every direction.
        grid: Processor grid, forwarded to `PFFT`

    Returns:
        `PFFT` object, or None in 1D or when no communicator is set
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    def no_transform(u, *args, **kwargs):
        return u

    # Identity "transforms" in every direction; the 1D bases' transforms are installed only for `axes`.
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    # NOTE(review): this relies on mpi4py-fft applying forward transforms starting from the last entry of `axes`,
    # so the non-distributable axes listed first are transformed last -- confirm against mpi4py-fft docs.
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )
    return pfft
def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    When using MPI, we use `PFFT` objects generated by mpi4py-fft

    Results are memoized in `self.fft_cache`, keyed on axes, direction, padding and shape.

    Args:
        axes (tuple): Axes you want to transform over. Defaults to all axes, given as negative indices.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing. Defaults to 1 in every direction.
        shape (tuple): Shape of the transform. Defaults to the global shape without the component axis.

    Returns:
        transform: Without a communicator, `xp.fft.fftn` / `xp.fft.ifftn` (not bound to `axes`, so the caller has
        to pass them) or None for "object". With a communicator, the `forward` / `backward` method of a `PFFT`
        object, or the `PFFT` object itself.
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )
    # NOTE(review): `axes` enters the key as given, so it must be hashable (a list would raise TypeError).
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache.keys():
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # Reuse (or create and cache) the PFFT object and hand out its bound forward/backward method.
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    # NOTE(review): for an unrecognized `direction` nothing is cached and this lookup raises KeyError.
    return self.fft_cache[key]
def local_slice(self, forward_output=True):
    """
    Get the slices of the global data held by this rank.

    Args:
        forward_output (bool): Slices of the spectral-space (True) or physical-space (False) layout

    Returns:
        list of slice: One slice per axis; full slices when no distributed FFT object exists
    """
    if self.fft_obj is not None:
        # `self.fft_obj` is the PFFT over all axes created in `setup_fft`. Use it directly, as `global_slice` does,
        # instead of going through the `get_pfft` cache with a different key for an equivalent object.
        return self.fft_obj.local_slice(forward_output=forward_output)
    return [slice(0, me.N) for me in self.axes]
def global_slice(self, forward_output=True):
    """
    Slices spanning the full global extent of every axis.

    Without a distributed FFT object this is the same as `local_slice`.

    Args:
        forward_output (bool): Use the spectral-space (True) or physical-space (False) layout
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    global_shape = self.fft_obj.global_shape(forward_output=forward_output)
    return [slice(0, extent) for extent in global_shape]
def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    Adds a component 'u' if none has been added. Sets `global_shape`, `fft_obj`, `init`, `init_physical`,
    `init_forward`, `BC_mat` and `BC_rhs_mask`.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_pfft(axes=axes)

    def _local_shape(forward_output):
        # Shape of this rank's part of the global data, computed from the local slices without allocating an
        # array of the full global shape. The component axis is never distributed.
        slices = self.local_slice(forward_output)
        return (self.global_shape[0],) + tuple(
            len(range(extent)[sl]) for extent, sl in zip(self.global_shape[1:], slices)
        )

    physical_shape = _local_shape(False)
    self.init = (physical_shape, self.comm, np.dtype('float'))
    self.init_physical = (physical_shape, self.comm, np.dtype('float'))
    self.init_forward = (
        _local_shape(True),
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)
def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Args:
        pfft: `PFFT` object defining the distribution. Defaults to `self.get_pfft()`.
        forward_output (bool): Spectral-space (True) or physical-space (False) layout
        val: Value to fill the array with
        rank (int): Number of leading component axes (only used for distributed arrays)
        view (bool): Return a plain view instead of the `DistArray` (only used for distributed arrays)

    Returns:
        Array with the solution components in the first axis, filled with `val`
    """
    if self.comm is None:
        # No distribution: local shape and dtype as set up in `setup_fft` for physical space.
        return self.xp.full(self.init[0], val, dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft if pfft else self.get_pfft()
    if pfft is None:
        arr = self.u_init_forward if forward_output else self.u_init
        if val != 0:
            arr[...] = val
        return arr

    global_shape = pfft.global_shape(forward_output)
    p0 = pfft.pencil[forward_output]
    if forward_output is True:
        dtype = pfft.forward.output_array.dtype
    else:
        dtype = pfft.forward.input_array.dtype
    global_shape = (self.ncomponents,) * rank + global_shape

    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
    return z.v if view else z
def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Find the axes along which the local array `u` spans the full global extent, i.e. along which it is aligned,
    by comparing its local shape with the global shape of a distributed array from a matching `PFFT` object.

    Args:
        u: Local array with the component axis first
        forward_output (bool): Compare against the spectral-space (True) or physical-space (False) layout
        padding (sequence of float): Padding that may have been applied to `u`. If given, combinations of padded
            and unpadded directions are tried in turn and the first one yielding an aligned axis is used.
        **kwargs: Forwarded to `get_pfft`

    Returns:
        list of int: Aligned axes, not counting the component axis. Always [0] without a communicator.

    Raises:
        AssertionError: If no aligned axis is found
    """
    if self.comm is None:
        return [0]

    def _alignment(pfft):
        # Axes where the local size of `u` equals the global size of the reference array.
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        # Candidate paddings: partially padded first, then fully padded, finally unpadded.
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes
def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Copy the local array `u` into a distributed array aligned in `axis`.

    Args:
        u: Local array with the component axis first
        axis (int): Axis in which the output should be aligned
        forward_output (bool): Whether the target array uses the spectral-space (True) or physical-space (False)
            layout
        **kwargs: Forwarded to `get_pfft` and `infer_alignment` (e.g. `padding`)

    Returns:
        Distributed array holding the data of `u` aligned in `axis`, or `u` itself without a communicator

    Raises:
        Exception: If `u` cannot be matched to any alignment of the target array
    """
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # Bring the target array into an alignment matching the local shape of `u`, copy the data over, and then
    # redistribute it to the requested axis.
    # NOTE(review): the alignment of `u` is always inferred against the physical layout (forward_output=False),
    # independent of the `forward_output` argument -- confirm this is intended.
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )
def transform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from physical to spectral space.

    Args:
        u: Data in physical space with the component axis first
        axes (sequence of int): Axes to transform along. Defaults to all axes.
        padding (sequence of float): Padding factors for dealiasing. Only used with distributed transforms.
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in spectral space
    """
    # `get_pfft` is memoized with its arguments as dictionary key, so they must be hashable: convert to tuples.
    axes = None if axes is None else tuple(axes)
    padding = None if padding is None else tuple(padding)

    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if pfft is None:
        # No distributed transform: apply the 1D transforms axis by axis. Axis 0 of `u` holds the components,
        # so non-negative axis indices are shifted by one.
        axes = axes if axes else tuple(i for i in range(self.ndim))
        u_hat = u.copy()
        for i in axes:
            _axis = 1 + i if i >= 0 else i
            u_hat = self.axes[i].transform(u_hat, axes=(_axis,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=False, rank=1)
    _out = self.newDistArray(pfft, forward_output=True, rank=1)

    if _in.shape == u.shape:
        _in[...] = u
    else:
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs)

    for i in range(self.ncomponents):
        pfft.forward(_in[i], _out[i], normalize=False)

    if padding is not None:
        _out /= np.prod(padding)
    return _out
def itransform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from spectral to physical space.

    Args:
        u: Data in spectral space with the component axis first
        axes (sequence of int): Axes to transform along. Defaults to all axes.
        padding (sequence of float): Padding factors for dealiasing. `N * padding` must be an integer for every
            axis. Only used with distributed transforms.
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in physical space
    """
    # `get_pfft` is memoized with its arguments as dictionary key, so they must be hashable: convert to tuples.
    axes = None if axes is None else tuple(axes)
    padding = None if padding is None else tuple(padding)

    if padding is not None:
        assert all(
            (self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim)
        ), 'Cannot do this padding with this resolution. Resulting resolution must be integer'

    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
    if pfft is None:
        # No distributed transform: apply the 1D inverse transforms axis by axis. Axis 0 of `u` holds the
        # components, so non-negative axis indices are shifted by one.
        axes = axes if axes else tuple(i for i in range(self.ndim))
        u_hat = u.copy()
        for i in axes:
            _axis = 1 + i if i >= 0 else i
            u_hat = self.axes[i].itransform(u_hat, axes=(_axis,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=True, rank=1)
    _out = self.newDistArray(pfft, forward_output=False, rank=1)

    if _in.shape == u.shape:
        _in[...] = u
    else:
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs)

    for i in range(self.ncomponents):
        pfft.backward(_in[i], _out[i], normalize=True)

    if padding is not None:
        _out *= np.prod(padding)
    return _out
def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
    which you can sort out via the `SpectralHelper.local_slice()` attribute. When constructing a 1D matrix, you can
    use this method to get the part corresponding to the modes carried by this rank.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions.

    Returns:
        sparse local matrix (rows and columns restricted to the local spectral slice, CSC format)
    """
    M_csc = M.tocsc()
    local = self.local_slice(True)[axis]
    return M_csc[local, local]
def expand_matrix_ND(self, matrix, aligned):
    """
    Expand a 1D matrix acting along axis `aligned` to the full N-dimensional (local) problem by taking Kronecker
    products with identities in all other directions. Every factor is restricted to the modes held by this rank.

    Args:
        matrix (sparse matrix): Global 1D matrix
        aligned (int): Axis the 1D matrix acts along

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = self.sparse_lib
    other_axes = np.delete(np.arange(self.ndim), aligned)
    ndim = len(other_axes) + 1

    if ndim == 1:
        return matrix
    if ndim > 3:
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    factors = [None] * ndim
    factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
    for ax in other_axes:
        factors[ax] = self.get_local_slice_of_1D_matrix(sp.eye(self.axes[ax].N), ax)

    if ndim == 2:
        return sp.kron(*factors)
    # `kron` takes exactly two matrices, so three factors are nested.
    return sp.kron(factors[0], sp.kron(*factors[1:]))
def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible.

    In more than one dimension the 1D filter is expanded with `expand_matrix_ND`, like the other operators. This
    handles three dimensions (a plain `kron(*mats)` would pass the third matrix as the `format` argument) and
    restricts the matrix to the modes held by this rank.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    filter_1D = self.axes[axis].get_filter_matrix(**kwargs)
    return self.expand_matrix_ND(filter_1D, axis)
def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation.

    The expanded 1D matrices of all requested axes are multiplied in the order given.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix
    """
    factors = [
        self.expand_matrix_ND(self.axes[ax].get_differentiation_matrix(**kwargs), ax) for ax in axes
    ]
    D = factors[0]
    for factor in factors[1:]:
        D = D @ factor
    return D
def get_integration_matrix(self, axes):
    """
    Get integration matrix to integrate along specified axis.

    The expanded 1D matrices of all requested axes are multiplied in the order given.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """
    factors = [self.expand_matrix_ND(self.axes[ax].get_integration_matrix(), ax) for ax in axes]
    S = factors[0]
    for factor in factors[1:]:
        S = S @ factor
    return S
def get_Id(self):
    """
    Get identity matrix: the product of the expanded 1D identities of all axes.

    Returns:
        sparse identity matrix
    """
    factors = [self.expand_matrix_ND(self.axes[ax].get_Id(), ax) for ax in range(self.ndim)]
    I = factors[0]
    for factor in factors[1:]:
        I = I @ factor
    return I
def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get Dirichlet recombination matrix along axis. Note that it only makes sense in directions discretized with
    variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    recombination_1D = self.axes[axis].get_Dirichlet_recombination_matrix()
    expanded = self.expand_matrix_ND(recombination_1D, axis)
    return expanded
def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases do a change between bases while differentiating. This method returns matrices that changes the basis to whatever you want.
    Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.

    The expanded 1D matrices are multiplied in the order of `axes`.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes, last axis first.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))

    factors = [self.expand_matrix_ND(self.axes[ax].get_basis_change_matrix(**kwargs), ax) for ax in axes]
    C = factors[0]
    for factor in factors[1:]:
        C = C @ factor
    return C