Coverage for pySDC/helpers/spectral_helper.py: 90%
782 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-04 15:08 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-04 15:08 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
5from functools import partial, wraps
6import logging
def cache(func):
    """
    Decorator for caching return values of methods on a per-instance basis.
    This is very similar to `functools.cache`, but without the memory leaks (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/): the cache is a dictionary stored as an attribute of the
    instance itself (named `_<method name>_cache`) instead of a global dictionary keyed on `self`.

    Example:

        .. code-block:: python

            class Counter:
                num_calls = 0

                @cache
                def increment(self, x):
                    self.num_calls += 1
                    return x + 1

            c = Counter()
            c.increment(0)  # returns 1, c.num_calls = 1
            c.increment(1)  # returns 2, c.num_calls = 2
            c.increment(0)  # returns 1, c.num_calls = 2

    Note that the cache key is built from the positional arguments and the keyword arguments separately, so
    `c.increment(0)` and `c.increment(x=0)` produce two different cache entries. All arguments must be hashable.

    Args:
        func (function): The method you want to cache the return value of. Its first argument must be the instance.

    Returns:
        function: Wrapped method returning cached values where available
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # The per-instance cache dictionary is created lazily on first use
        store = getattr(self, attr_cache, None)
        if store is None:
            store = {}
            setattr(self, attr_cache, store)

        key = (args, frozenset(kwargs.items()))
        if key in store:
            return store[key]
        result = func(self, *args, **kwargs)
        store[key] = result
        return result

    return wrapper
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary. Must be supplied (subclasses provide defaults).
            x1 (float): Coordinate of right boundary. Must be supplied (subclasses provide defaults).
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both `useGPU` and `useFFTW` are requested
        """
        # Validate before importing optional GPU / FFTW modules or switching the class-level modules
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

        # keep the bookkeeping attributes consistent with `setup_CPU`
        cls.fft_backend = 'cupyx-scipy'
        cls.fft_comm_backend = 'NCCL'
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """switch to CPU modules"""

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros. Note that every entry references the same object `O`, so replace
            entries rather than modifying them in place.
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
              individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Modes with `kmin <= |k| < kmax` are kept, all others are set to zero.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter. Defaults to the largest absolute wavenumber, which is
              then itself filtered out.

        Returns:
            sparse matrix
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        if self.useGPU:
            # The filter is assembled with SciPy on the host, so the mask has to be moved there as well
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        F = F.tocsc()

        # hand the result back in the sparse format of the active backend
        return self.sparse_lib.csc_matrix(F) if self.useGPU else F

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary. The domain is mapped affinely onto the reference interval [-1, 1].
            x1 (float): Coordinate of right boundary. The domain is mapped affinely onto the reference interval [-1, 1].
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
            - T: Chebychev polynomials of first kind
            - U: Chebychev polynomials of second kind
            - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its reverse is implemented
        '''
        N = N if N else self.N
        sp = self.sparse_lib

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.eye(N, k=2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.eye(N, k=2)
            elif name[0] == name[-1]:
                # converting a basis into itself: identity of the requested size
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # Only forward conversions are implemented explicitly; the backward ones are obtained by inversion
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}') from E

            import scipy.sparse as scipy_sparse

            if self.sparse_lib == scipy_sparse:
                mat = self.sparse_lib.linalg.inv(fwd.tocsc())
            else:
                # The sparse inverse is not available on GPU, so it is computed on the CPU
                mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: For any `lbnd` other than 0
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = self.xp.arange(self.N)

        # first row holds the integration constant for integrating from x=0
        S = S.tocsc()
        S[0, 1::2] = (
            (n / (2 * (self.xp.arange(self.N) + 1)))[1::2]
            * (-1) ** (self.xp.arange(self.N // 2))
            / (self.xp.concatenate((self.xp.ones(1), self.xp.arange(self.N // 2 - 1) + 1)))
        ) * self.lin_trf_fac
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        # Entry [k, j] of the first derivative is 2 * j for k < j with j - k odd, and zero otherwise.
        # The first row is halved.
        j = self.xp.arange(self.N)
        k = j[:, None]
        D = self.xp.where(k < j, 2.0 * j * ((j - k) % 2), 0.0)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # The copy accumulates the re-sorting and de-normalization of *all* axes. Re-copying `u` inside the loop
        # would discard the work done for previous axes, breaking multi-axis transforms.
        _u = u.copy()
        for axis in axes:

            if self.N != _u.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                N = self.N
                src = _u
                _u = self.xp.zeros_like(src)

                # copy first half
                su = [slice(None)] * src.ndim
                su[axis] = slice(0, N // 2 + 1)
                _u[tuple(su)] = src[tuple(su)]

                # copy second half
                su = [slice(None)] * src.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * src.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                _u[tuple(s_u)] = src[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * src.ndim
                    su[axis] = N // 2
                    _u[tuple(su)] *= 2

            # undo the normalization that `transform` applied along this axis
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N

            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` and `'neumann'` boundary conditions, which is the
        coordinate at which to set the BC.

        Args:
            kind ('integral', 'dirichlet' or 'neumann'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        elif kind.lower() == 'neumann':
            return self.get_Neumann_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the reference interval [-1, 1].
        NOTE(review): the row is not scaled by `lin_trf_fac`, i.e. it refers to reference coordinates for domains
        other than [-1, 1] -- confirm this is intended.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition in the reference interval; only -1, 0 and 1 are implemented

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Neumann_BC_row(self, x):
        """
        Get a row for generating Neumann BCs at x with T polynomials.
        It returns the derivatives of the T polynomials at x, i.e. k**2 at x=1 and (-1)**(k+1) * k**2 at x=-1.

        Args:
            x (float): Position of the boundary condition in the reference interval; only -1 and 1 are implemented

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N, dtype='D')
        nn = n**2
        if x == -1:
            me = nn
            me[1:] *= (-1) ** n[:-1]
            return me
        elif x == 1:
            return nn
        else:
            raise NotImplementedError(f'Don\'t know how to generate Neumann BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib

        return sp.eye(N) - sp.eye(N, k=2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped
    from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.
        The matrix maps from the T basis to derivative base `p`.

        Args:
            p (int): Order of the derivative. `p=0` yields the identity.

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p

        if l == 0:
            # the general formula contains factorial(-1) = 0 and would yield a zero matrix
            return sp.eye(N)

        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # assembled with SciPy because column assignment in LIL format is needed
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        elif p_out == p_in:
            # no conversion necessary, skip inverting the identity
            return self.sparse_lib.csc_matrix(mat_fwd)

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of
        `UltrasphericalHelper.get_integration_matrix` for details.
        The constant is sum_{k >= 1} (-1)**(k+1) * u_hat[k] along `axis`, which makes the integral vanish at x=-1.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        n = u_hat.shape[axis]

        # `slice(None)` keeps all other axes; `None` would insert new axes instead
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # align the alternating signs with `axis` so that broadcasting works for any dimension
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = n - 1
        signs = ((-1) ** self.xp.arange(n - 1)).reshape(sign_shape)

        return self.xp.sum(u_hat[tuple(slices)] * signs, axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic domains. Differentiation and integration matrices are diagonal in this basis.
    """

    distributable = True

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary `x0` and excluding the right boundary `x1`, which
        coincides with the left boundary due to periodicity.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive: it follows `fftfreq`, i.e. zero, positive, then negative
        wavenumbers.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()

        if self.useGPU:
            if p == 1:
                return self.sparse_lib.diags(1j * k)
            else:
                # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
                from scipy.sparse.linalg import matrix_power

                D = self.sparse_lib.diags(1j * k).get()
                return self.sparse_lib.csc_matrix(matrix_power(D, p))
        else:
            return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        # the zero mode has no inverse derivative; replace it to avoid division by zero
        k[0] = 1j * self.L
        return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p)

    def get_plan(self, u, forward, *args, **kwargs):
        """
        Get a callable that performs the FFT. With FFTW, plans are created once per combination of direction, shape,
        positional and keyword arguments and stored in `self.plans`. For other libraries, `fftn` or `ifftn` is returned
        with the normalization fixed.

        Args:
            u: Data the transform will be applied to
            forward (bool): Whether to get a forward or backward transform

        Returns:
            callable transform
        """
        if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
            if 'axes' in kwargs.keys():
                kwargs['axes'] = tuple(kwargs['axes'])
            key = (forward, u.shape, args, *(me for me in kwargs.values()))
            if key in self.plans.keys():
                return self.plans[key]
            else:
                self.logger.debug(f'Generating FFT plan for {key=}')
                transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs)
                self.plans[key] = transform

            return self.plans[key]
        else:
            if forward:
                return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward'))
            else:
                return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward'))

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        FFT along axes. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform over

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally: passing it as keyword after `*args` clashes with the first extra argument
        plan = self.get_plan(u, True, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs)

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse FFT.

        Args:
            u: Data you want to transform
            axes (tuple): Axes over which to transform

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally: passing it as keyword after `*args` clashes with the first extra argument
        plan = self.get_plan(u, False, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        Nyquist_mode = k.min()
        return self.xp.where(k == Nyquist_mode)[0][0]

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
909class SpectralHelper:
910 """
911 This class has three functions:
912 - Easily assemble matrices containing multiple equations
913 - Direct product of 1D bases to solve problems in more dimensions
914 - Distribute the FFTs to facilitate concurrency.
916 Attributes:
917 comm (mpi4py.Intracomm): MPI communicator
918 debug (bool): Perform additional checks at extra computational cost
919 useGPU (bool): Whether to use GPUs
920 axes (list): List of 1D bases
921 components (list): List of strings of the names of components in the equations
922 full_BCs (list): List of Dictionaries containing all information about the boundary conditions
923 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
924 BCs (sparse matrix): Matrix containing only the BCs
925 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
926 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
927 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs
928 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
929 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
930 global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
931 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
932 init (tuple): This is the same `init` that is used throughout the problem classes
933 init_forward (tuple): This is the equivalent of `init` in spectral space
934 """
936 xp = np
937 fft_lib = scipy.fft
938 sparse_lib = scipy.sparse
939 linalg = scipy.sparse.linalg
940 dtype = mesh
941 fft_backend = 'scipy'
942 fft_comm_backend = 'MPI'
944 @classmethod
945 def setup_GPU(cls):
946 """switch to GPU modules"""
947 import cupy as cp
948 import cupyx.scipy.sparse as sparse_lib
949 import cupyx.scipy.sparse.linalg as linalg
950 import cupyx.scipy.fft as fft_lib
951 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
953 cls.xp = cp
954 cls.sparse_lib = sparse_lib
955 cls.linalg = linalg
957 cls.fft_lib = fft_lib
958 cls.fft_backend = 'cupyx-scipy'
959 cls.fft_comm_backend = 'NCCL'
961 cls.dtype = cupy_mesh
963 @classmethod
964 def setup_CPU(cls, useFFTW=False):
965 """switch to CPU modules"""
967 cls.xp = np
968 cls.sparse_lib = scipy.sparse
969 cls.linalg = scipy.sparse.linalg
971 if useFFTW:
972 from mpi4py_fft import fftw
974 cls.fft_backend = 'fftw'
975 cls.fft_lib = fftw
976 else:
977 cls.fft_backend = 'scipy'
978 cls.fft_lib = scipy.fft
980 cls.fft_comm_backend = 'MPI'
981 cls.dtype = mesh
983 def __init__(self, comm=None, useGPU=False, debug=False):
984 """
985 Constructor
987 Args:
988 comm (mpi4py.Intracomm): MPI communicator
989 useGPU (bool): Whether to use GPUs
990 debug (bool): Perform additional checks at extra computational cost
991 """
992 self.comm = comm
993 self.debug = debug
994 self.useGPU = useGPU
996 if useGPU:
997 self.setup_GPU()
998 else:
999 self.setup_CPU()
1001 self.axes = []
1002 self.components = []
1004 self.full_BCs = []
1005 self.BC_mat = None
1006 self.BCs = None
1008 self.fft_cache = {}
1009 self.fft_dealias_shape_cache = {}
1011 self.logger = logging.getLogger(name='Spectral Discretization')
1012 if debug:
1013 self.logger.setLevel(logging.DEBUG)
1015 @property
1016 def u_init(self):
1017 """
1018 Get empty data container in physical space
1019 """
1020 return self.dtype(self.init)
1022 @property
1023 def u_init_forward(self):
1024 """
1025 Get empty data container in spectral space
1026 """
1027 return self.dtype(self.init_forward)
1029 @property
1030 def u_init_physical(self):
1031 """
1032 Get empty data container in physical space
1033 """
1034 return self.dtype(self.init_physical)
1036 @property
1037 def shape(self):
1038 """
1039 Get shape of individual solution component
1040 """
1041 return self.init[0][1:]
1043 @property
1044 def ndim(self):
1045 return len(self.axes)
1047 @property
1048 def ncomponents(self):
1049 return len(self.components)
@property
def V(self):
    """
    Volume of the domain: the product of the lengths `L` of all axes.
    """
    lengths = [axis.L for axis in self.axes]
    return np.prod(lengths)
def add_axis(self, base, *args, **kwargs):
    """
    Append a new axis to the domain, discretized with the 1D base selected by `base` (case-insensitive).
    `*args` and `**kwargs` go to the constructor of the 1D base; `useGPU` is always set from this helper.
    The new axis inherits this helper's `xp` and `sparse_lib` modules.

    Args:
        base (str): 1D spectral method, e.g. 'chebychev', 'fft' or 'ultraspherical'

    Raises:
        NotImplementedError: If `base` is not a known 1D spectral method
    """
    kwargs['useGPU'] = self.useGPU

    name = base.lower()
    if name in ('chebychov', 'chebychev', 'cheby', 'chebychovhelper'):
        helper_cls = ChebychevHelper
    elif name in ('fft', 'fourier', 'ffthelper'):
        helper_cls = FFTHelper
    elif name in ('ultraspherical', 'gegenbauer'):
        helper_cls = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    new_axis = helper_cls(*args, **kwargs)
    self.axes.append(new_axis)
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list/tuple of str): Name(s) of component(s). Lists/tuples are added element by element.

    Raises:
        Exception: If a component of this name was already added
        NotImplementedError: If `name` is neither a string nor a list/tuple
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError(f'Don\'t know how to add a component given as {type(name)=}')
def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str, int, or list/tuple thereof): Name(s) of component(s)

    Returns:
        int for a single name, or a tuple of ints (one per entry, in order) for a list/tuple of names

    Raises:
        NotImplementedError: If `name` has an unsupported type
    """
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # Materialize as a tuple (rather than a one-shot generator) so the result supports unpacking,
        # repeated iteration, `len` and indexing.
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
def get_empty_operator_matrix(self, diag=False):
    """
    Create a container for the operators that couple the solution components, filled with sparse zero matrices.
    All entries refer to the same zero matrix object; the rows of the nested list are independent lists.

    Args:
        diag (bool): Return only one entry per component (block-diagonal operator)

    Returns:
        list (diag) or list of lists of sparse zero matrices
    """
    zero = self.get_Id() * 0
    n = len(self.components)
    if diag:
        return [zero] * n
    return [[zero] * n for _ in range(n)]
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): If True, the BC is only applied at entry 0 of the other direction(s) instead of at all
            positions there (presumably the zeroth mode, as the matrices act on spectral coefficients)

    Returns:
        sparse matrix containing the BC, restricted to the locally held modes

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # all-zero N x N host matrix in LIL format (cheap row assignment); row `line` receives the 1D BC row
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # `.get()` copies the GPU result to the host, since `BC` is a scipy (host) matrix
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        mat = self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # diag(1, 0, ..., 0): keep only entry 0 in the other direction
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        # Kronecker product with the factor of axis 0 first
        mats = [
            None,
        ] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
    elif ndim == 3:
        mats = [
            None,
        ] * ndim

        for ax in range(ndim):
            if ax == axis:
                continue

            if scalar:
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
            else:
                _Id = self.axes[ax].get_Id()

            mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

        # assigned after the loop so that a negative `axis` (skipped by `ax == axis` above) still ends up with BC
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

        # `kron` takes two matrices, hence the nesting
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
    else:
        raise NotImplementedError(
            f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
        )
    mat = self.eliminate_zeros(mat)
    return mat
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC acts on
        equation (str): Name of the equation the BC is in
        axis (int): Axis of the BC
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line the BC is in (global index; negative values count from the end)
        scalar (bool): Whether to remove a scalar BC, i.e. one acting only at entry 0 of the other direction(s)
    """
    # `get_BC` already returns a matrix with explicit zeros eliminated
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    N = self.axes[axis].N
    global_line = (N + line) % N
    if global_line not in self.xp.arange(N)[self.local_slice()[axis]]:
        # this rank does not hold the line, so there is nothing to unmark in the local rhs mask
        return

    # `line` is a global index while the mask is local: shift by the start of the local slice, as in `add_BC`
    local_line = global_line - self.local_slice()[axis].start

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
    else:
        slices = [self.index(equation), *self.global_slice(True)]
    slices[axis + 1] = local_line
    self.BC_rhs_mask[(*slices,)] = False
def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Besides updating `self.BC_mat`, this records the BC in `self.full_BCs` (used to put BC values into the rhs) and
    marks the affected rhs entries in `self.BC_rhs_mask`.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in (global index; negative values count from the end)
        scalar (bool): If True, the BC is only applied at entry 0 of the other direction(s) instead of at all
            positions there
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        # single entry: 0 in every other direction, `line` along `axis`
        slices = [self.index(equation)] + [
            0,
        ] * self.ndim
        slices[axis + 1] = line
        # NOTE(review): with a communicator, only rank 0 marks the entry and `line` is used as a local index
        # directly; if `axis` is distributed, confirm that rank 0 actually holds this line.
        if self.comm:
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        # whole line: full local extent in all other directions
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        # only mark it if this rank holds the (global) line; convert it to a local index first
        if (N + line) % N in self.get_indices(True)[axis]:
            slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = True
def setup_BCs(self):
    """
    Assemble the boundary-condition operator from `self.BC_mat`, together with the helpers needed for boundary
    bordering:
    - `BC_zero_index`: flat indices of the rows that carry a BC
    - `BC_line_zero_matrix`: diagonal matrix that zeros exactly those rows
    - `rhs_BCs_hat`: the BC values put into an empty rhs and transformed to spectral space
    """
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    flat_mask = self.BC_rhs_mask.flatten()
    self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[flat_mask]

    # ones on the diagonal, except in rows that carry a BC
    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = self.sparse_lib.diags(keep_rows)

    # BC values in spectral space, ready to be added to the rhs
    self.rhs_BCs_hat = self.transform(self.put_BCs_in_rhs(self.u_init))
def check_BCs(self, u):
    """
    Assert that `u` satisfies every non-scalar boundary condition recorded in `self.full_BCs`.
    Only 1D and 2D problems are supported.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a BC is violated or the problem has three or more dimensions
    """
    assert self.ndim < 3
    excluded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']
    for axis in range(self.ndim):
        BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]
        if not BCs:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for BC in BCs:
            bc_kwargs = {key: BC[key] for key in BC if key not in excluded}
            bc_row = self.axes[axis].get_BC(**bc_kwargs)
            data = u_hat[self.index(BC['component'])]
            # apply the BC row along the axis it belongs to
            if axis == 0:
                get = bc_row @ data
            elif axis == 1:
                get = data @ bc_row
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'
def put_BCs_in_matrix(self, A):
    """
    Border the matrix `A` with the boundary conditions: rows that carry a BC are zeroed, then the BC operator is
    added.
    """
    bordered = self.BC_line_zero_matrix @ A
    return bordered + self.BCs
def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This needs no transforms. The mask of the affected entries is built on the first call and cached in
    `self._rhs_hat_zero_mask`.

    Note that the BC entries of the passed `rhs_hat` are set to zero in place.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Mark the rhs entries that are replaced by boundary conditions: one line per BC, on the ranks holding it.
        self._rhs_hat_zero_mask = self.newDistArray().astype(bool)

        for axis in range(self.ndim):
            for bc in (b for b in self.full_BCs if b['axis'] == axis):
                N = self.axes[axis].N
                row = (N + bc['line']) % N
                if row not in self.get_indices(True)[axis]:
                    continue
                where = [self.index(bc['equation']), *self.global_slice(True)]
                where[axis + 1] = row - self.local_slice()[axis].start
                self._rhs_hat_zero_mask[tuple(where)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat
def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    The rhs is transformed along one axis at a time, the BC values for that axis are written, and it is
    transformed back. Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space (not flattened)

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim
    for axis in range(ndim):
        transform_axes = (axis - ndim,)
        rhs_hat = self.transform(rhs, axes=transform_axes)

        for bc in (b for b in self.full_BCs if b['axis'] == axis):
            N = self.axes[axis].N
            row = (N + bc['line']) % N
            # only write the value on ranks that hold this (global) line, converted to a local index
            if row in self.get_indices(True)[axis]:
                where = [self.index(bc['equation']), *self.global_slice(True)]
                where[axis + 1] = row - self.local_slice()[axis].start
                rhs_hat[tuple(where)] = bc['v']

        rhs = self.itransform(rhs_hat, axes=transform_axes)

    return rhs
def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (the one you want to solve implicitly) of an equation to a list of lists of sparse
    matrices, which you later convert to a single operator.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python

            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled, modified in place
        equation (str): The equation of the component you want this in
        relations (dict): Maps component names to the operator acting on that component in this equation
    """
    for component in relations:
        A[self.index(equation)][self.index(component)] = relations[component]
def eliminate_zeros(self, A):
    """
    Remove explicitly stored zeros from a sparse matrix, which can reduce its memory footprint somewhat.
    Note: At the time of writing, cupy's `eliminate_zeros` has memory problems. On GPU, the matrix is therefore
    copied to the host, pruned there and copied back.

    Args:
        A: sparse matrix to be pruned

    Returns:
        CSC sparse matrix
    """
    on_gpu = self.useGPU
    host_matrix = (A.get() if on_gpu else A).tocsc()
    host_matrix.eliminate_zeros()
    return self.sparse_lib.csc_matrix(host_matrix) if on_gpu else host_matrix
def convert_operator_matrix_to_operator(self, M):
    """
    Combine the list of lists of sparse matrices into a single sparse linear operator with explicit zeros removed.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator blocks

    Returns:
        sparse linear operator
    """
    single_component = len(self.components) == 1
    op = M[0][0] if single_component else self.sparse_lib.bmat(M, format='csc')
    return self.eliminate_zeros(op)
def get_wavenumbers(self):
    """
    Get the locally held part of the grid in spectral space, as `indexing='ij'` meshgrid arrays (one per axis).
    """
    local = self.local_slice(True)
    grids = [base.get_wavenumbers()[local[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')
def get_grid(self, forward_output=False):
    """
    Get the locally held part of the grid in physical space, as `indexing='ij'` meshgrid arrays (one per axis).

    Args:
        forward_output (bool): Select the local slices of the spectral (True) or physical (False) layout
    """
    local = self.local_slice(forward_output)
    grids = [base.get_1dgrid()[local[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')
def get_indices(self, forward_output=True):
    """
    Global indices held by this rank, one array per axis.

    Args:
        forward_output (bool): Use the spectral (True) or physical (False) layout
    """
    local = self.local_slice(forward_output)
    return [self.xp.arange(base.N)[local[i]] for i, base in enumerate(self.axes)]
@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Get a (cached) mpi4py-fft `PFFT` object for distributed transforms of a single solution component.

    Because of the `@cache` decorator, all arguments must be hashable (e.g. pass `axes` and `padding` as tuples).

    Args:
        axes (tuple): Axes to transform; all axes if None. Negative values are allowed.
        padding (tuple): Padding factor per axis for dealiasing; no padding if None
        grid: Forwarded to `PFFT` as the processor grid

    Returns:
        PFFT object, or None for 1D problems or when no communicator is set
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    def no_transform(u, *args, **kwargs):
        return u

    # every axis gets an identity "transform"; the requested axes get the transforms of their 1D bases
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )
    return pfft
def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get (and cache in `self.fft_cache`) an FFT function or object. With MPI, `PFFT` objects from mpi4py-fft are
    used; without a communicator, the `fftn`/`ifftn` functions of `self.xp.fft` are returned.

    Args:
        axes (tuple): Axes you want to transform over; all axes (counted from the back) if None
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing; no padding if None
        shape (tuple): Shape of the transform; the global shape of one component if None

    Returns:
        transform (None for "object" without a communicator)

    Raises:
        AssertionError: If padding is requested without a communicator
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))
    if shape is None:
        shape = self.global_shape[1:]
    if padding is None:
        padding = [1] * self.ndim
    key = (axes, direction, tuple(padding), tuple(shape))

    cached = self.fft_cache
    if key in cached:
        return cached[key]

    if self.comm is None:
        assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'
        if direction == 'forward':
            cached[key] = self.xp.fft.fftn
        elif direction == 'backward':
            cached[key] = self.xp.fft.ifftn
        elif direction == 'object':
            cached[key] = None
        return cached[key]

    if direction == 'object':
        from mpi4py_fft import PFFT

        cached[key] = PFFT(
            comm=self.comm,
            shape=shape,
            axes=sorted(axes),
            dtype='D',
            collapse=False,
            backend=self.fft_backend,
            comm_backend=self.fft_comm_backend,
            padding=padding,
        )
    else:
        # the transform functions are bound methods of the (cached) PFFT object with the same parameters
        pfft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)
        if direction == 'forward':
            cached[key] = pfft.forward
        elif direction == 'backward':
            cached[key] = pfft.backward

    return cached[key]
def local_slice(self, forward_output=True):
    """
    Slices, one per axis, selecting the part of the global data held by this rank.
    Without a distributed FFT object, these are full slices over every axis.

    Args:
        forward_output (bool): Use the spectral (True) or physical (False) layout
    """
    if not self.fft_obj:
        return [slice(0, base.N) for base in self.axes]
    return self.get_pfft().local_slice(forward_output=forward_output)
def global_slice(self, forward_output=True):
    """
    Slices, one per axis, spanning the full global extent. Without a distributed FFT object, this is the same as
    `local_slice`.

    Args:
        forward_output (bool): Use the spectral (True) or physical (False) layout
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    extents = self.fft_obj.global_shape(forward_output=forward_output)
    return [slice(0, n) for n in extents]
def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    Sets `global_shape` and `fft_obj`, and the `init`, `init_physical` and `init_forward` tuples of
    (local shape, communicator, dtype). It also sets the empty BC operator matrix `BC_mat` and the boolean
    rhs mask `BC_rhs_mask`. If no component was added, a single component 'u' is added.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_pfft(axes=axes)

    def _local_shape(forward_output):
        # Shape of `self.global_shape` restricted to the local slices (applied to the trailing axes).
        # This is computed from the slices instead of indexing `np.empty(self.global_shape)`, which would request
        # memory for the full global array on every rank and can fail for large distributed problems.
        slices = tuple(self.local_slice(forward_output))
        n_leading = len(self.global_shape) - len(slices)
        leading = tuple(self.global_shape[:n_leading])
        sliced = tuple(len(range(*s.indices(n))) for s, n in zip(slices, self.global_shape[n_leading:]))
        return leading + sliced

    self.init = (_local_shape(False), self.comm, np.dtype('float'))
    self.init_physical = (_local_shape(False), self.comm, np.dtype('float'))
    self.init_forward = (
        _local_shape(True),
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)
def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Args:
        pfft: PFFT object describing the distribution; `self.get_pfft()` is used if not given
        forward_output (bool): Use the layout and dtype of spectral (True) or physical (False) space
        val: Forwarded as `val` to the distributed array constructor (presumably the initial fill value)
        rank (int): Number of leading component dimensions, each of length `self.ncomponents`
        view (bool): Return `z.v` (presumably the plain array view) instead of the distributed array object

    Returns:
        Array with leading component dimension(s)
    """
    if self.comm is None:
        # NOTE(review): without a communicator, `pfft`, `forward_output`, `val`, `rank` and `view` are ignored and
        # the array always has the shape and dtype of `self.init` (physical space) -- confirm this is intended
        return self.xp.zeros(self.init[0], dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft if pfft else self.get_pfft()
    if pfft is None:
        # `get_pfft` returns None for 1D problems: fall back to the (non-distributed) data containers
        if forward_output:
            return self.u_init_forward
        else:
            return self.u_init

    global_shape = pfft.global_shape(forward_output)
    p0 = pfft.pencil[forward_output]
    if forward_output is True:
        dtype = pfft.forward.output_array.dtype
    else:
        dtype = pfft.forward.input_array.dtype
    # prepend `rank` component dimensions to the shape of a single component
    global_shape = (self.ncomponents,) * rank + global_shape

    # use the CuPy-backed distributed array class when the FFT backend runs on the GPU
    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
    return z.v if view else z
def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Find the axes in which the local array `u` is aligned, by comparing its shape with distributed arrays built
    from candidate PFFT objects.

    Args:
        u: Local array with the component axis first
        forward_output (bool): Compare with the spectral (True) or physical (False) layout
        padding (tuple): If given, PFFT objects with this padding applied to different subsets of the axes are
            tried in turn, and the first one that yields any aligned axis is used (only 2D and 3D supported)
        **kwargs: Forwarded to `get_pfft`

    Returns:
        list of int: Aligned axes (not counting the component axis); `[0]` if no communicator is set

    Raises:
        NotImplementedError: If `padding` is given for a problem that is neither 2D nor 3D
        AssertionError: If no aligned axis is found
    """
    if self.comm is None:
        return [0]

    def _alignment(pfft):
        # Axes where the local size of `u` equals the global size of a distributed array of `pfft`,
        # presumably meaning `u` is not split across ranks in those axes.
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        # candidate paddings: `padding` applied to every subset of axes, ending with no padding at all
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes
def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Redistribute the local array `u` so that the result is aligned in `axis`, using a distributed array of the
    PFFT object selected by `kwargs`.

    Args:
        u: Local data array
        axis (int): Axis the output should be aligned in
        forward_output (bool): Whether the target array has the spectral (True) or physical (False) layout
        **kwargs: Forwarded to `get_pfft` and `infer_alignment` (e.g. `padding`)

    Returns:
        Distributed array aligned in `axis`, or `u` itself if no communicator is set

    Raises:
        Exception: If none of the inferred alignments of the target array matches the local shape of `u`
    """
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # NOTE(review): the alignment of `u` is always inferred with `forward_output=False`, also when redistributing
    # spectral data -- confirm this is intended
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        # bring the target array into the layout of `u`, copy the data, then realign as requested
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )
def transform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from physical to spectral space.

    Args:
        u: Data in physical space, component axis first
        axes (tuple): Axes to transform; all axes if None or empty
        padding (tuple): Padding for dealiasing, forwarded to `get_pfft`
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in spectral space
    """
    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if pfft is None:
        # no distributed FFT: apply the 1D transforms one axis at a time to a copy of `u`
        u_hat = u.copy()
        for i in axes or tuple(range(self.ndim)):
            data_axis = i if i < 0 else i + 1  # non-negative indices skip the leading component axis
            u_hat = self.axes[i].transform(u_hat, axes=(data_axis,))
        return u_hat

    src = self.newDistArray(pfft, forward_output=False, rank=1)
    dst = self.newDistArray(pfft, forward_output=True, rank=1)

    if src.shape == u.shape:
        src[...] = u
    else:
        src[...] = self.redistribute(u, axis=src.alignment, forward_output=False, padding=padding, **kwargs)

    for comp in range(self.ncomponents):
        pfft.forward(src[comp], dst[comp], normalize=False)

    if padding is not None:
        dst /= np.prod(padding)
    return dst
def itransform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from spectral to physical space.

    Args:
        u: Data in spectral space, component axis first
        axes (tuple): Axes to transform; all axes if None or empty
        padding (tuple): Padding for dealiasing, forwarded to `get_pfft`. `N * padding` must be an integer
            for every axis.
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in physical space
    """
    if padding is not None:
        assert all(
            (self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim)
        ), 'Cannot do this padding with this resolution. Resulting resolution must be integer'

    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
    if pfft is None:
        # no distributed FFT: apply the 1D inverse transforms one axis at a time to a copy of `u`
        u_phys = u.copy()
        for i in axes or tuple(range(self.ndim)):
            data_axis = i if i < 0 else i + 1  # non-negative indices skip the leading component axis
            u_phys = self.axes[i].itransform(u_phys, axes=(data_axis,))
        return u_phys

    src = self.newDistArray(pfft, forward_output=True, rank=1)
    dst = self.newDistArray(pfft, forward_output=False, rank=1)

    if src.shape == u.shape:
        src[...] = u
    else:
        src[...] = self.redistribute(u, axis=src.alignment, forward_output=True, padding=padding, **kwargs)

    for comp in range(self.ncomponents):
        pfft.backward(src[comp], dst[comp], normalize=True)

    if padding is not None:
        dst *= np.prod(padding)
    return dst
def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Restrict a global 1D matrix to the modes held by this rank in direction `axis`. With distributed FFTs, each
    rank carries only a subset of the modes, as given by `local_slice()`.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions.

    Returns:
        sparse local matrix (CSC)
    """
    local = self.local_slice(True)[axis]
    return M.tocsc()[local, local]
def expand_matrix_ND(self, matrix, aligned):
    """
    Embed a 1D matrix acting along axis `aligned` into the full (local) ND space. The embedding is a Kronecker
    product with identities in all other directions; each factor is restricted to the locally held modes.

    Args:
        matrix (sparse matrix): Global 1D matrix
        aligned (int): Axis the matrix acts on

    Returns:
        sparse matrix (CSC, explicit zeros removed)

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = self.sparse_lib
    other_axes = np.delete(np.arange(self.ndim), aligned)
    ndim = len(other_axes) + 1

    if ndim > 3:
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    if ndim == 1:
        mat = matrix
    else:
        factors = [None] * ndim
        factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
        for axis in other_axes:
            factors[axis] = self.get_local_slice_of_1D_matrix(sp.eye(self.axes[axis].N), axis)

        if ndim == 2:
            mat = sp.kron(factors[0], factors[1])
        else:
            mat = sp.kron(factors[0], sp.kron(factors[1], factors[2]))

    return self.eliminate_zeros(mat)
def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible. In all other directions, the identity of the respective 1D base is used.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix, restricted to the locally held modes when there is more than one axis
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    mats = [base.get_Id() for base in self.axes]
    mats[axis] = self.axes[axis].get_filter_matrix(**kwargs)

    # restrict every 1D factor to the modes held by this rank, like the other ND operators
    mats = [self.get_local_slice_of_1D_matrix(mat, axis=i) for i, mat in enumerate(mats)]

    # `kron` accepts only two matrices (its third positional parameter is `format`), so fold the factors
    # pairwise: kron(A, kron(B, C)) for three dimensions, kron(A, B) for two.
    F = mats[-1]
    for mat in reversed(mats[:-1]):
        F = self.sparse_lib.kron(mat, F)
    return F
def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get the ND differentiation matrix: the product of the expanded 1D differentiation matrices of all `axes`, in
    the order given. `kwargs` are forwarded to the 1D base implementation.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix
    """

    def _expanded(axis):
        return self.expand_matrix_ND(self.axes[axis].get_differentiation_matrix(**kwargs), axis)

    D = _expanded(axes[0])
    for axis in axes[1:]:
        D = D @ _expanded(axis)
    return D
def get_integration_matrix(self, axes):
    """
    Get the ND integration matrix: the product of the expanded 1D integration matrices of all `axes`, in the
    order given.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """

    def _expanded(axis):
        return self.expand_matrix_ND(self.axes[axis].get_integration_matrix(), axis)

    S = _expanded(axes[0])
    for axis in axes[1:]:
        S = S @ _expanded(axis)
    return S
def get_Id(self):
    """
    Get the ND identity matrix: the product of the expanded identities of the 1D bases of all axes.

    Returns:
        sparse identity matrix
    """
    expand = self.expand_matrix_ND
    I = expand(self.axes[0].get_Id(), 0)
    for axis in range(1, self.ndim):
        I = I @ expand(self.axes[axis].get_Id(), axis)
    return I
def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get the Dirichlet recombination matrix along `axis`, expanded to ND. Note that this only makes sense in
    directions discretized with variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    recombination_1D = self.axes[axis].get_Dirichlet_recombination_matrix()
    return self.expand_matrix_ND(recombination_1D, axis)
def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases change the basis while differentiating. This returns the ND matrix that changes the basis
    as requested: the product of the expanded 1D basis change matrices of all `axes`, in the order given.
    Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.

    Args:
        axes (tuple): Axes along which to change basis; all axes, counted from the back, if None.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))

    def _expanded(axis):
        return self.expand_matrix_ND(self.axes[axis].get_basis_change_matrix(**kwargs), axis)

    C = _expanded(axes[0])
    for axis in axes[1:]:
        C = C @ _expanded(axis)
    return C