Coverage for pySDC/helpers/spectral_helper.py: 93%
743 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 07:24 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 07:24 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
5from functools import wraps
8def cache(func):
9 """
10 Decorator for caching return values of functions.
11 This is very similar to `functools.cache`, but without the memory leaks (see
12 https://docs.astral.sh/ruff/rules/cached-instance-method/).
14 Example:
16 .. code-block:: python
18 num_calls = 0
20 @cache
21 def increment(x):
22 num_calls += 1
23 return x + 1
25 increment(0) # returns 1, num_calls = 1
26 increment(1) # returns 2, num_calls = 2
27 increment(0) # returns 1, num_calls = 2
30 Args:
31 func (function): The function you want to cache the return value of
33 Returns:
34 return value of func
35 """
36 attr_cache = f"_{func.__name__}_cache"
38 @wraps(func)
39 def wrapper(self, *args, **kwargs):
40 if not hasattr(self, attr_cache):
41 setattr(self, attr_cache, {})
43 cache = getattr(self, attr_cache)
45 key = (args, frozenset(kwargs.items()))
46 if key in cache:
47 return cache[key]
48 result = func(self, *args, **kwargs)
49 cache[key] = result
50 return result
52 return wrapper
55class SpectralHelper1D:
56 """
57 Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
58 all bases need to have.
60 When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
61 the code for GPUs.
63 Attributes:
64 N (int): Resolution
65 x0 (float): Coordinate of left boundary
66 x1 (float): Coordinate of right boundary
67 L (float): Length of the domain
68 useGPU (bool): Whether to use GPUs
70 """
72 fft_lib = scipy.fft
73 sparse_lib = scipy.sparse
74 linalg = scipy.sparse.linalg
75 xp = np
77 def __init__(self, N, x0=None, x1=None, useGPU=False):
78 """
79 Constructor
81 Args:
82 N (int): Resolution
83 x0 (float): Coordinate of left boundary
84 x1 (float): Coordinate of right boundary
85 useGPU (bool): Whether to use GPUs
86 """
87 self.N = N
88 self.x0 = x0
89 self.x1 = x1
90 self.L = x1 - x0
91 self.useGPU = useGPU
93 if useGPU:
94 self.setup_GPU()
96 @classmethod
97 def setup_GPU(cls):
98 """switch to GPU modules"""
99 import cupy as cp
100 import cupyx.scipy.sparse as sparse_lib
101 import cupyx.scipy.sparse.linalg as linalg
102 import cupyx.scipy.fft as fft_lib
103 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
105 cls.xp = cp
106 cls.sparse_lib = sparse_lib
107 cls.linalg = linalg
108 cls.fft_lib = fft_lib
110 def get_Id(self):
111 """
112 Get identity matrix
114 Returns:
115 sparse diagonal identity matrix
116 """
117 return self.sparse_lib.eye(self.N)
119 def get_zero(self):
120 """
121 Get a matrix with all zeros of the correct size.
123 Returns:
124 sparse matrix with zeros everywhere
125 """
126 return 0 * self.get_Id()
128 def get_differentiation_matrix(self):
129 raise NotImplementedError()
131 def get_integration_matrix(self):
132 raise NotImplementedError()
134 def get_wavenumbers(self):
135 """
136 Get the grid in spectral space
137 """
138 raise NotImplementedError
140 def get_empty_operator_matrix(self, S, O):
141 """
142 Return a matrix of operators to be filled with the connections between the solution components.
144 Args:
145 S (int): Number of components in the solution
146 O (sparse matrix): Zero matrix used for initialization
148 Returns:
149 list of lists containing sparse zeros
150 """
151 return [[O for _ in range(S)] for _ in range(S)]
153 def get_basis_change_matrix(self, *args, **kwargs):
154 """
155 Some spectral discretization change the basis during differentiation. This method can be used to transfer
156 between the various bases.
158 This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
159 dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
160 The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
161 base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
162 base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
163 discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
164 you want. This is what the `SpectralHelper` does when you call the method of the same name on it.
166 Returns:
167 sparse bases change matrix
168 """
169 return self.sparse_lib.eye(self.N)
171 def get_BC(self, kind):
172 """
173 To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
174 as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
175 the matrix are replaced by the boundary conditions as obtained by this method.
177 Args:
178 kind (str): The type of BC you want to implement please refer to the implementations of this method in the
179 individual 1D bases for what is implemented
181 Returns:
182 self.xp.array: Boundary condition
183 """
184 raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')
186 def get_filter_matrix(self, kmin=0, kmax=None):
187 """
188 Get a bandpass filter.
190 Args:
191 kmin (int): Lower limit of the bandpass filter
192 kmax (int): Upper limit of the bandpass filter
194 Returns:
195 sparse matrix
196 """
198 k = abs(self.get_wavenumbers())
200 kmax = max(k) if kmax is None else kmax
202 mask = self.xp.logical_or(k >= kmax, k < kmin)
204 if self.useGPU:
205 Id = self.get_Id().get()
206 else:
207 Id = self.get_Id()
208 F = Id.tolil()
209 F[:, mask] = 0
210 return F.tocsc()
212 def get_1dgrid(self):
213 """
214 Get the grid in physical space
216 Returns:
217 self.xp.array: Grid
218 """
219 raise NotImplementedError
222class ChebychevHelper(SpectralHelper1D):
223 """
224 The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
225 between physical and spectral space by discrete cosine transform.
226 The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
227 that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
228 be formulated in first order formulation.
230 This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
231 """
233 def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
234 """
235 Constructor.
236 Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
237 may choose to run on GPUs via the `useGPU` argument.
239 Args:
240 transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
241 use the FFT from the library to compute the DCT
242 x0 (float): Coordinate of left boundary. Note that only -1 is currently implented
243 x1 (float): Coordinate of right boundary. Note that only +1 is currently implented
244 """
245 # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
246 self.lin_trf_fac = (x1 - x0) / 2
247 self.lin_trf_off = (x1 + x0) / 2
248 super().__init__(*args, x0=x0, x1=x1, **kwargs)
249 self.transform_type = transform_type
251 if self.transform_type == 'fft':
252 self.get_fft_utils()
254 self.norm = self.get_norm()
256 def get_1dgrid(self):
257 '''
258 Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
259 use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
260 need to do an affine transformation before any Chebychev business.
262 Returns:
263 numpy.ndarray: 1D grid
264 '''
265 return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off
267 def get_wavenumbers(self):
268 """Get the domain in spectral space"""
269 return self.xp.arange(self.N)
271 @cache
272 def get_conv(self, name, N=None):
273 '''
274 Get conversion matrix between different kinds of polynomials. The supported kinds are
275 - T: Chebychev polynomials of first kind
276 - U: Chebychev polynomials of second kind
277 - D: Dirichlet recombination.
279 You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
280 Once generates matrices are cached. So feel free to call the method as often as you like.
282 Args:
283 name (str): Conversion code, e.g. 'T2U'
284 N (int): Size of the matrix (optional)
286 Returns:
287 scipy.sparse: Sparse conversion matrix
288 '''
289 N = N if N else self.N
290 sp = self.sparse_lib
291 xp = self.xp
293 def get_forward_conv(name):
294 if name == 'T2U':
295 mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)) / 2.0
296 mat[:, 0] *= 2
297 elif name == 'D2T':
298 mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
299 elif name[0] == name[-1]:
300 mat = self.sparse_lib.eye(self.N)
301 else:
302 raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
303 return mat
305 try:
306 mat = get_forward_conv(name)
307 except NotImplementedError as E:
308 try:
309 fwd = get_forward_conv(name[::-1])
310 import scipy.sparse as sp
312 if self.sparse_lib == sp:
313 mat = self.sparse_lib.linalg.inv(fwd.tocsc())
314 else:
315 mat = self.sparse_lib.csc_matrix(sp.linalg.inv(fwd.tocsc().get()))
316 except NotImplementedError:
317 raise NotImplementedError from E
319 return mat
321 def get_basis_change_matrix(self, conv='T2T', **kwargs):
322 """
323 As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
324 Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
325 conversion matrix from `ChebychevHelper.get_conv`. Not that `**kwargs` are used to absorb arguments for other
326 bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.
328 Args:
329 conv (str): Conversion code, i.e. T2U
331 Returns:
332 Sparse conversion matrix
333 """
334 return self.get_conv(conv)
336 def get_integration_matrix(self, lbnd=0):
337 """
338 Get matrix for integration
340 Args:
341 lbnd (float): Lower bound for integration, only 0 is currently implemented
343 Returns:
344 Sparse integration matrix
345 """
346 S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
347 n = self.xp.arange(self.N)
348 if lbnd == 0:
349 S = S.tocsc()
350 S[0, 1::2] = (
351 (n / (2 * (self.xp.arange(self.N) + 1)))[1::2]
352 * (-1) ** (self.xp.arange(self.N // 2))
353 / (np.append([1], self.xp.arange(self.N // 2 - 1) + 1))
354 ) * self.lin_trf_fac
355 else:
356 raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')
357 return S
359 def get_differentiation_matrix(self, p=1):
360 '''
361 Keep in mind that the T2T differentiation matrix is dense.
363 Args:
364 p (int): Derivative you want to compute
366 Returns:
367 numpy.ndarray: Differentiation matrix
368 '''
369 D = self.xp.zeros((self.N, self.N))
370 for j in range(self.N):
371 for k in range(j):
372 D[k, j] = 2 * j * ((j - k) % 2)
374 D[0, :] /= 2
375 return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p
377 def get_norm(self, N=None):
378 '''
379 Get normalization for converting Chebychev coefficients and DCT
381 Args:
382 N (int, optional): Resolution
384 Returns:
385 self.xp.array: Normalization
386 '''
387 N = self.N if N is None else N
388 norm = self.xp.ones(N) / N
389 norm[0] /= 2
390 return norm
392 def get_fft_shuffle(self, forward, N):
393 """
394 In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
395 doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
396 in the complex plane. This function returns a mask to do the shuffling.
398 Args:
399 forward (bool): Whether you want the shuffle for forward transform or backward transform
400 N (int): size of the grid
402 Returns:
403 self.xp.array: Use as mask
404 """
405 xp = self.xp
406 if forward:
407 return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
408 else:
409 mask = xp.zeros(N, dtype=int)
410 mask[: N - N % 2 : 2] = xp.arange(N // 2)
411 mask[1::2] = N - xp.arange(N // 2) - 1
412 mask[-1] = N // 2
413 return mask
415 def get_fft_shift(self, forward, N):
416 """
417 As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.
419 Args:
420 forward (bool): Whether you want the rotation for forward transform or backward transform
421 N (int): size of the grid
423 Returns:
424 self.xp.array: Rotation
425 """
426 k = self.get_wavenumbers()
427 norm = self.get_norm()
428 xp = self.xp
429 if forward:
430 return 2 * xp.exp(-1j * np.pi * k / (2 * N) + 0j * np.pi / 4) * norm
431 else:
432 shift = xp.exp(1j * np.pi * k / (2 * N))
433 shift[0] = 0.5
434 return shift / norm
436 def get_fft_utils(self):
437 """
438 Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
439 them cached.
440 """
441 self.fft_utils = {
442 'fwd': {},
443 'bck': {},
444 }
446 # forwards transform
447 self.fft_utils['fwd']['shuffle'] = self.get_fft_shuffle(True, self.N)
448 self.fft_utils['fwd']['shift'] = self.get_fft_shift(True, self.N)
450 # backwards transform
451 self.fft_utils['bck']['shuffle'] = self.get_fft_shuffle(False, self.N)
452 self.fft_utils['bck']['shift'] = self.get_fft_shift(False, self.N)
454 return self.fft_utils
456 def transform(self, u, axis=-1, **kwargs):
457 """
458 1D DCT along axis. `kwargs` will be passed on to the FFT library.
460 Args:
461 u: Data you want to transform
462 axis (int): Axis you want to transform along
464 Returns:
465 Data in spectral space
466 """
467 if self.transform_type.lower() == 'dct':
468 return self.fft_lib.dct(u, axis=axis, **kwargs) * self.norm
469 elif self.transform_type.lower() == 'fft':
470 result = u.copy()
472 shuffle = [slice(0, s, 1) for s in u.shape]
473 shuffle[axis] = self.fft_utils['fwd']['shuffle']
475 v = u[(*shuffle,)]
477 V = self.fft_lib.fft(v, axis=axis, **kwargs)
479 expansion = [np.newaxis for _ in u.shape]
480 expansion[axis] = slice(0, u.shape[axis], 1)
482 V *= self.fft_utils['fwd']['shift'][(*expansion,)]
484 result.real[...] = V.real[...]
485 return result
486 else:
487 raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')
489 def itransform(self, u, axis=-1):
490 """
491 1D inverse DCT along axis.
493 Args:
494 u: Data you want to transform
495 axis (int): Axis you want to transform along
497 Returns:
498 Data in physical space
499 """
500 assert self.norm.shape[0] == u.shape[axis]
502 if self.transform_type == 'dct':
503 return self.fft_lib.idct(u / self.norm, axis=axis)
504 elif self.transform_type == 'fft':
505 result = u.copy()
507 expansion = [np.newaxis for _ in u.shape]
508 expansion[axis] = slice(0, u.shape[axis], 1)
510 v = self.fft_lib.ifft(u * self.fft_utils['bck']['shift'][(*expansion,)], axis=axis)
512 shuffle = [slice(0, s, 1) for s in u.shape]
513 shuffle[axis] = self.fft_utils['bck']['shuffle']
514 V = v[(*shuffle,)]
516 result.real[...] = V.real[...]
517 return result
518 else:
519 raise NotImplementedError
521 def get_BC(self, kind, **kwargs):
522 """
523 Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
524 the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
525 set the BC.
527 Args:
528 kind ('integral' or 'dirichlet'): Kind of boundary condition you want
529 """
530 if kind.lower() == 'integral':
531 return self.get_integ_BC_row(**kwargs)
532 elif kind.lower() == 'dirichlet':
533 return self.get_Dirichlet_BC_row(**kwargs)
534 else:
535 return super().get_BC(kind)
537 def get_integ_BC_row(self):
538 """
539 Get a row for generating integral BCs with T polynomials.
540 It returns the values of the integrals of T polynomials over the entire interval.
542 Returns:
543 self.xp.ndarray: Row to put into a matrix
544 """
545 n = self.xp.arange(self.N) + 1
546 me = self.xp.zeros_like(n).astype(float)
547 me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
548 me[0] = 2.0
549 return me
551 def get_Dirichlet_BC_row(self, x):
552 """
553 Get a row for generating Dirichlet BCs at x with T polynomials.
554 It returns the values of the T polynomials at x.
556 Args:
557 x (float): Position of the boundary condition
559 Returns:
560 self.xp.ndarray: Row to put into a matrix
561 """
562 if x == -1:
563 return (-1) ** self.xp.arange(self.N)
564 elif x == 1:
565 return self.xp.ones(self.N)
566 elif x == 0:
567 n = (1 + (-1) ** self.xp.arange(self.N)) / 2
568 n[2::4] *= -1
569 return n
570 else:
571 raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')
573 def get_Dirichlet_recombination_matrix(self):
574 '''
575 Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
576 This makes for a good right preconditioner.
578 Returns:
579 scipy.sparse: Sparse conversion matrix
580 '''
581 N = self.N
582 sp = self.sparse_lib
583 xp = self.xp
585 return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
588class UltrasphericalHelper(ChebychevHelper):
589 """
590 This implementation follows https://doi.org/10.1137/120865458.
591 The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
592 The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
593 There are also conversion operators from one derivative basis to the next that are sparse.
595 This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
596 You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
597 """
599 def get_differentiation_matrix(self, p=1):
600 """
601 Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.
603 Args:
604 p (int): Order of the derivative
606 Returns:
607 sparse differentiation matrix
608 """
609 sp = self.sparse_lib
610 xp = self.xp
611 N = self.N
612 l = p
613 return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p
615 def get_S(self, lmbda):
616 """
617 Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
618 https://doi.org/10.1137/120865458.
620 Args:
621 lmbda (int): Ingoing derivative base
623 Returns:
624 sparse matrix: Conversion from derivative base lmbda to lmbda + 1
625 """
626 N = self.N
628 if lmbda == 0:
629 sp = scipy.sparse
630 mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
631 mat[:, 0] *= 2
632 else:
633 sp = self.sparse_lib
634 xp = self.xp
635 mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
636 lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
637 )
639 return self.sparse_lib.csc_matrix(mat)
641 def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
642 """
643 Get a conversion matrix from derivative base `p_in` to `p_out`.
645 Args:
646 p_out (int): Resulting derivative base
647 p_in (int): Ingoing derivative base
648 """
649 mat_fwd = self.sparse_lib.eye(self.N)
650 for i in range(min([p_in, p_out]), max([p_in, p_out])):
651 mat_fwd = self.get_S(i) @ mat_fwd
653 if p_out > p_in:
654 return mat_fwd
656 else:
657 # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
658 import scipy.sparse as sp
660 if self.useGPU:
661 mat_fwd = mat_fwd.get()
663 mat_bck = sp.linalg.inv(mat_fwd.tocsc())
665 return self.sparse_lib.csc_matrix(mat_bck)
667 def get_integration_matrix(self):
668 """
669 Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
670 integration constant such that integration starts from x=-1.
672 Example:
674 .. code-block:: python
676 import numpy as np
677 from pySDC.helpers.spectral_helper import UltrasphericalHelper
679 N = 4
680 helper = UltrasphericalHelper(N)
681 coeffs = np.random.random(N)
682 coeffs[-1] = 0
684 poly = np.polynomial.Chebyshev(coeffs)
686 S = helper.get_integration_matrix()
687 U_hat = S @ coeffs
688 U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)
690 assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)
692 Returns:
693 sparse integration matrix
694 """
695 return (
696 self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
697 @ self.get_basis_change_matrix(p_out=1, p_in=0)
698 * self.lin_trf_fac
699 )
701 def get_integration_constant(self, u_hat, axis):
702 """
703 Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.
705 Args:
706 u_hat: Solution in spectral space
707 axis: Axis you want to integrate over
709 Returns:
710 Integration constant, has one less dimension than `u_hat`
711 """
712 slices = [
713 None,
714 ] * u_hat.ndim
715 slices[axis] = slice(1, u_hat.shape[axis])
716 return self.xp.sum(u_hat[(*slices,)] * (-1) ** (self.xp.arange(u_hat.shape[axis] - 1)), axis=axis)
719class FFTHelper(SpectralHelper1D):
720 def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
721 """
722 Constructor.
723 Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
724 may choose to run on GPUs via the `useGPU` argument.
726 Args:
727 transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
728 use the FFT from the library to compute the DCT
729 x0 (float, optional): Coordinate of left boundary
730 x1 (float, optional): Coordinate of right boundary
731 """
732 super().__init__(*args, x0=x0, x1=x1, **kwargs)
734 def get_1dgrid(self):
735 """
736 We use equally spaced points including the left boundary and not including the right one, which is the left boundary.
737 """
738 dx = self.L / self.N
739 return self.xp.arange(self.N) * dx + self.x0
741 def get_wavenumbers(self):
742 """
743 Be careful that this ordering is very unintuitive.
744 """
745 return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L
747 def get_differentiation_matrix(self, p=1):
748 """
749 This matrix is diagonal, allowing to invert concurrently.
751 Args:
752 p (int): Order of the derivative
754 Returns:
755 sparse differentiation matrix
756 """
757 k = self.get_wavenumbers()
759 if self.useGPU:
760 # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
761 import scipy.sparse as sp
763 D = self.sparse_lib.diags(1j * k).get()
764 return self.sparse_lib.csc_matrix(sp.linalg.matrix_power(D, p))
765 else:
766 return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p)
768 def get_integration_matrix(self, p=1):
769 """
770 Get integration matrix to compute `p`-th integral over the entire domain.
772 Args:
773 p (int): Order of integral you want to compute
775 Returns:
776 sparse integration matrix
777 """
778 k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
779 k[0] = 1j * self.L
780 return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p)
782 def transform(self, u, axis=-1, **kwargs):
783 """
784 1D FFT along axis. `kwargs` are passed on to the FFT library.
786 Args:
787 u: Data you want to transform
788 axis (int): Axis you want to transform along
790 Returns:
791 transformed data
792 """
793 return self.fft_lib.fft(u, axis=axis, **kwargs)
795 def itransform(self, u, axis=-1):
796 """
797 Inverse 1D FFT.
799 Args:
800 u: Data you want to transform
801 axis (int): Axis you want to transform along
803 Returns:
804 transformed data
805 """
806 return self.fft_lib.ifft(u, axis=axis)
808 def get_BC(self, kind):
809 """
810 Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
811 The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
812 You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.
814 Args:
815 kind ('integral' or 'nyquist'): Kind of BC
817 Returns:
818 self.xp.ndarray: Boundary condition row
819 """
820 if kind.lower() == 'integral':
821 return self.get_integ_BC_row()
822 elif kind.lower() == 'nyquist':
823 assert (
824 self.N % 2 == 0
825 ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
826 BC = self.xp.zeros(self.N)
827 BC[self.get_Nyquist_mode_index()] = 1
828 return BC
829 else:
830 return super().get_BC(kind)
832 def get_Nyquist_mode_index(self):
833 """
834 Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
835 counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
836 are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
837 after.
839 Returns:
840 int: Index of the Nyquist mode
841 """
842 k = self.get_wavenumbers()
843 Nyquist_mode = min(k)
844 return self.xp.where(k == Nyquist_mode)[0][0]
846 def get_integ_BC_row(self):
847 """
848 Only the 0-mode has non-zero integral with FFT basis in periodic BCs
849 """
850 me = self.xp.zeros(self.N)
851 me[0] = self.L / self.N
852 return me
855class SpectralHelper:
856 """
857 This class has three functions:
858 - Easily assemble matrices containing multiple equations
859 - Direct product of 1D bases to solve problems in more dimensions
860 - Distribute the FFTs to facilitate concurrency.
862 Attributes:
863 comm (mpi4py.Intracomm): MPI communicator
864 debug (bool): Perform additional checks at extra computational cost
865 useGPU (bool): Whether to use GPUs
866 axes (list): List of 1D bases
867 components (list): List of strings of the names of components in the equations
868 full_BCs (list): List of Dictionaries containing all information about the boundary conditions
869 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
870 BCs (sparse matrix): Matrix containing only the BCs
871 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
872 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
873 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs
874 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
875 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
876 global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
877 local_slice (slice): Local slice of the solution as in `mpi4py-fft`
878 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
879 init (tuple): This is the same `init` that is used throughout the problem classes
880 init_forward (tuple): This is the equivalent of `init` in spectral space
881 """
883 xp = np
884 fft_lib = scipy.fft
885 sparse_lib = scipy.sparse
886 linalg = scipy.sparse.linalg
887 dtype = mesh
888 fft_backend = 'fftw'
889 fft_comm_backend = 'MPI'
891 @classmethod
892 def setup_GPU(cls):
893 """switch to GPU modules"""
894 import cupy as cp
895 import cupyx.scipy.sparse as sparse_lib
896 import cupyx.scipy.sparse.linalg as linalg
897 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
899 cls.xp = cp
900 cls.sparse_lib = sparse_lib
901 cls.linalg = linalg
903 cls.fft_backend = 'cupy'
904 cls.fft_comm_backend = 'NCCL'
906 cls.dtype = cupy_mesh
908 def __init__(self, comm=None, useGPU=False, debug=False):
909 """
910 Constructor
912 Args:
913 comm (mpi4py.Intracomm): MPI communicator
914 useGPU (bool): Whether to use GPUs
915 debug (bool): Perform additional checks at extra computational cost
916 """
917 self.comm = comm
918 self.debug = debug
919 self.useGPU = useGPU
921 if useGPU:
922 self.setup_GPU()
924 self.axes = []
925 self.components = []
927 self.full_BCs = []
928 self.BC_mat = None
929 self.BCs = None
931 self.fft_cache = {}
932 self.fft_dealias_shape_cache = {}
934 @property
935 def u_init(self):
936 """
937 Get empty data container in physical space
938 """
939 return self.dtype(self.init)
941 @property
942 def u_init_forward(self):
943 """
944 Get empty data container in spectral space
945 """
946 return self.dtype(self.init_forward)
948 @property
949 def shape(self):
950 """
951 Get shape of individual solution component
952 """
953 return self.init[0][1:]
955 @property
956 def ndim(self):
957 return len(self.axes)
959 @property
960 def ncomponents(self):
961 return len(self.components)
963 @property
964 def V(self):
965 """
966 Get domain volume
967 """
968 return np.prod([me.L for me in self.axes])
970 def add_axis(self, base, *args, **kwargs):
971 """
972 Add an axis to the domain by deciding on suitable 1D base.
973 Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
974 bases for possible arguments.
976 Args:
977 base (str): 1D spectral method
978 """
979 kwargs['useGPU'] = self.useGPU
981 if base.lower() in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
982 kwargs['transform_type'] = kwargs.get('transform_type', 'fft')
983 self.axes.append(ChebychevHelper(*args, **kwargs))
984 elif base.lower() in ['fft', 'fourier', 'ffthelper']:
985 self.axes.append(FFTHelper(*args, **kwargs))
986 elif base.lower() in ['ultraspherical', 'gegenbauer']:
987 self.axes.append(UltrasphericalHelper(*args, **kwargs))
988 else:
989 raise NotImplementedError(f'{base=!r} is not implemented!')
990 self.axes[-1].xp = self.xp
991 self.axes[-1].sparse_lib = self.sparse_lib
993 def add_component(self, name):
994 """
995 Add solution component(s).
997 Args:
998 name (str or list of strings): Name(s) of component(s)
999 """
1000 if type(name) in [list, tuple]:
1001 for me in name:
1002 self.add_component(me)
1003 elif type(name) in [str]:
1004 if name in self.components:
1005 raise Exception(f'{name=!r} is already added to this problem!')
1006 self.components.append(name)
1007 else:
1008 raise NotImplementedError
1010 def index(self, name):
1011 """
1012 Get the index of component `name`.
1014 Args:
1015 name (str or list of strings): Name(s) of component(s)
1017 Returns:
1018 int: Index of the component
1019 """
1020 if type(name) in [str, int]:
1021 return self.components.index(name)
1022 elif type(name) in [list, tuple]:
1023 return (self.index(me) for me in name)
1024 else:
1025 raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
1027 def get_empty_operator_matrix(self):
1028 """
1029 Return a matrix of operators to be filled with the connections between the solution components.
1031 Returns:
1032 list containing sparse zeros
1033 """
1034 S = len(self.components)
1035 O = self.get_Id() * 0
1036 return [[O for _ in range(S)] for _ in range(S)]
1038 def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
1039 """
1040 Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
1041 Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
1042 Typically, the last line that does not contain a BC is the best choice.
1043 Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.
1045 Args:
1046 axis (int): Axis you want to add the BC to
1047 kind (str): kind of BC, e.g. Dirichlet
1048 line (int): Line you want the BC to go in
1049 scalar (bool): Put the BC in all space positions in the other direction
1051 Returns:
1052 sparse matrix containing the BC
1053 """
1054 sp = scipy.sparse
1056 base = self.axes[axis]
1058 BC = sp.eye(base.N).tolil() * 0
1059 if self.useGPU:
1060 BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
1061 else:
1062 BC[line, :] = base.get_BC(kind=kind, **kwargs)
1064 ndim = len(self.axes)
1065 if ndim == 1:
1066 return self.sparse_lib.csc_matrix(BC)
1067 elif ndim == 2:
1068 axis2 = (axis + 1) % ndim
1070 if scalar:
1071 _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
1072 else:
1073 _Id = self.axes[axis2].get_Id()
1075 Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)
1077 if self.useGPU:
1078 Id = Id.get()
1080 mats = [
1081 None,
1082 ] * ndim
1083 mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
1084 mats[axis2] = Id
1085 return self.sparse_lib.csc_matrix(sp.kron(*mats))
1087 def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
1088 """
1089 Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
1090 remove single BCs again, as in incompressible Navier-Stokes, for instance.
1091 Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.
1093 Args:
1094 component (str): Name of the component the BC should act on
1095 equation (str): Name of the equation for the component you want to put the BC in
1096 axis (int): Axis you want to add the BC to
1097 kind (str): kind of BC, e.g. Dirichlet
1098 v: Value of the BC
1099 line (int): Line you want the BC to go in
1100 scalar (bool): Put the BC in all space positions in the other direction
1101 """
1102 _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
1103 self.BC_mat[self.index(equation)][self.index(component)] -= _BC
1105 if scalar:
1106 slices = [self.index(equation)] + [
1107 0,
1108 ] * self.ndim
1109 slices[axis + 1] = line
1110 else:
1111 slices = (
1112 [self.index(equation)]
1113 + [slice(0, self.init[0][i + 1]) for i in range(axis)]
1114 + [line]
1115 + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
1116 )
1117 N = self.axes[axis].N
1118 if (N + line) % N in self.xp.arange(N)[self.local_slice[axis]]:
1119 self.BC_rhs_mask[(*slices,)] = False
1121 def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
1122 """
1123 Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
1124 single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
1125 Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.
1127 Args:
1128 component (str): Name of the component the BC should act on
1129 equation (str): Name of the equation for the component you want to put the BC in
1130 axis (int): Axis you want to add the BC to
1131 kind (str): kind of BC, e.g. Dirichlet
1132 v: Value of the BC
1133 line (int): Line you want the BC to go in
1134 scalar (bool): Put the BC in all space positions in the other direction
1135 """
1136 _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
1137 self.BC_mat[self.index(equation)][self.index(component)] += _BC
1138 self.full_BCs += [
1139 {
1140 'component': component,
1141 'equation': equation,
1142 'axis': axis,
1143 'kind': kind,
1144 'v': v,
1145 'line': line,
1146 'scalar': scalar,
1147 **kwargs,
1148 }
1149 ]
1151 if scalar:
1152 slices = [self.index(equation)] + [
1153 0,
1154 ] * self.ndim
1155 slices[axis + 1] = line
1156 if self.comm:
1157 if self.comm.rank == 0:
1158 self.BC_rhs_mask[(*slices,)] = True
1159 else:
1160 self.BC_rhs_mask[(*slices,)] = True
1161 else:
1162 slices = (
1163 [self.index(equation)]
1164 + [slice(0, self.init[0][i + 1]) for i in range(axis)]
1165 + [line]
1166 + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
1167 )
1168 N = self.axes[axis].N
1169 if (N + line) % N in self.xp.arange(N)[self.local_slice[axis]]:
1170 slices[axis + 1] -= self.local_slice[axis].start
1171 self.BC_rhs_mask[(*slices,)] = True
1173 def setup_BCs(self):
1174 """
1175 Convert the list of lists of BCs to the boundary condition operator.
1176 Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
1177 condition. This method sets up a suitable sparse matrix to do this.
1178 """
1179 sp = self.sparse_lib
1180 self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)
1181 self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[self.BC_rhs_mask.flatten()]
1183 diags = self.xp.ones(self.BCs.shape[0])
1184 diags[self.BC_zero_index] = 0
1185 self.BC_line_zero_matrix = sp.diags(diags)
1187 # prepare BCs in spectral space to easily add to the RHS
1188 rhs_BCs = self.put_BCs_in_rhs(self.u_init)
1189 self.rhs_BCs_hat = self.transform(rhs_BCs)
1191 def check_BCs(self, u):
1192 """
1193 Check that the solution satisfies the boundary conditions
1195 Args:
1196 u: The solution you want to check
1197 """
1198 assert self.ndim < 3
1199 for axis in range(self.ndim):
1200 BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]
1202 if len(BCs) > 0:
1203 u_hat = self.transform(u, axes=(axis - self.ndim,))
1204 for BC in BCs:
1205 kwargs = {
1206 key: value
1207 for key, value in BC.items()
1208 if key not in ['component', 'equation', 'axis', 'v', 'line', 'scalar']
1209 }
1211 if axis == 0:
1212 get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
1213 elif axis == 1:
1214 get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
1215 want = BC['v']
1216 assert self.xp.allclose(
1217 get, want
1218 ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'
1220 def put_BCs_in_matrix(self, A):
1221 """
1222 Put the boundary conditions in a matrix by replacing rows with BCs.
1223 """
1224 return self.BC_line_zero_matrix @ A + self.BCs
1226 def put_BCs_in_rhs_hat(self, rhs_hat):
1227 """
1228 Put the BCs in the right hand side in spectral space for solving.
1229 This function needs no transforms and caches a mask for faster subsequent use.
1231 Args:
1232 rhs_hat: Right hand side in spectral space
1234 Returns:
1235 rhs in spectral space with BCs
1236 """
1237 if not hasattr(self, '_rhs_hat_zero_mask'):
1238 """
1239 Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
1240 by the boundary conditions. The mask is then cached.
1241 """
1242 self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)
1244 for axis in range(self.ndim):
1245 for bc in self.full_BCs:
1246 slices = (
1247 [slice(0, self.init[0][i + 1]) for i in range(axis)]
1248 + [bc['line']]
1249 + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
1250 )
1251 if axis == bc['axis']:
1252 _slice = [self.index(bc['equation'])] + slices
1253 N = self.axes[axis].N
1254 if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
1255 _slice[axis + 1] -= self.local_slice[axis].start
1256 self._rhs_hat_zero_mask[(*_slice,)] = True
1258 rhs_hat[self._rhs_hat_zero_mask] = 0
1259 return rhs_hat + self.rhs_BCs_hat
1261 def put_BCs_in_rhs(self, rhs):
1262 """
1263 Put the BCs in the right hand side for solving.
1264 This function will transform along each axis individually and add all BCs in that axis.
1265 Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.
1267 Args:
1268 rhs: Right hand side in physical space
1270 Returns:
1271 rhs in physical space with BCs
1272 """
1273 assert rhs.ndim > 1, 'rhs must not be flattened here!'
1275 ndim = self.ndim
1277 for axis in range(ndim):
1278 _rhs_hat = self.transform(rhs, axes=(axis - ndim,))
1280 for bc in self.full_BCs:
1281 slices = (
1282 [slice(0, self.init[0][i + 1]) for i in range(axis)]
1283 + [bc['line']]
1284 + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
1285 )
1286 if axis == bc['axis']:
1287 _slice = [self.index(bc['equation'])] + slices
1289 N = self.axes[axis].N
1290 if (N + bc['line']) % N in self.xp.arange(N)[self.local_slice[axis]]:
1291 _slice[axis + 1] -= self.local_slice[axis].start
1293 _rhs_hat[(*_slice,)] = bc['v']
1295 rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))
1297 return rhs
1299 def add_equation_lhs(self, A, equation, relations):
1300 """
1301 Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
1302 that you will convert to an operator later.
1304 Example:
1305 Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
1306 preconditioning:
1308 .. code-block:: python
1309 helper = SpectralHelper()
1311 helper.add_axis(base='chebychev', N=8)
1312 helper.add_component(['u', 'ux'])
1313 helper.setup_fft()
1315 I = helper.get_Id()
1316 Dx = helper.get_differentiation_matrix(axes=(0,))
1317 T2U = helper.get_basis_change_matrix('T2U')
1319 L_lhs = {
1320 'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
1321 'u': {'ux': -(T2U @ Dx)},
1322 }
1324 operator = helper.get_empty_operator_matrix()
1325 for line, equation in L_lhs.items():
1326 helper.add_equation_lhs(operator, line, equation)
1328 L = helper.convert_operator_matrix_to_operator(operator)
1330 Args:
1331 A (list of lists of sparse matrices): The operator to be
1332 equation (str): The equation of the component you want this in
1333 relations: (dict): Relations between quantities
1334 """
1335 for k, v in relations.items():
1336 A[self.index(equation)][self.index(k)] = v
1338 def convert_operator_matrix_to_operator(self, M):
1339 """
1340 Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
1341 See documentation of `SpectralHelper.add_equation_lhs` for an example.
1343 Args:
1344 M (list of lists of sparse matrices): The operator to be
1346 Returns:
1347 sparse linear operator
1348 """
1349 if len(self.components) == 1:
1350 return M[0][0]
1351 else:
1352 return self.sparse_lib.bmat(M, format='csc')
1354 def get_wavenumbers(self):
1355 """
1356 Get grid in spectral space
1357 """
1358 grids = [self.axes[i].get_wavenumbers()[self.local_slice[i]] for i in range(len(self.axes))]
1359 return self.xp.meshgrid(*grids, indexing='ij')
1361 def get_grid(self):
1362 """
1363 Get grid in physical space
1364 """
1365 grids = [self.axes[i].get_1dgrid()[self.local_slice[i]] for i in range(len(self.axes))]
1366 return self.xp.meshgrid(*grids, indexing='ij')
1368 def get_fft(self, axes=None, direction='object', padding=None, shape=None):
1369 """
1370 When using MPI, we use `PFFT` objects generated by mpi4py-fft
1372 Args:
1373 axes (tuple): Axes you want to transform over
1374 direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
1375 padding (tuple): Padding for dealiasing
1376 shape (tuple): Shape of the transform
1378 Returns:
1379 transform
1380 """
1381 axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
1382 shape = self.global_shape[1:] if shape is None else shape
1383 padding = (
1384 [
1385 1,
1386 ]
1387 * self.ndim
1388 if padding is None
1389 else padding
1390 )
1391 key = (axes, direction, tuple(padding), tuple(shape))
1393 if key not in self.fft_cache.keys():
1394 if self.comm is None:
1395 assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'
1397 if direction == 'forward':
1398 self.fft_cache[key] = self.xp.fft.fftn
1399 elif direction == 'backward':
1400 self.fft_cache[key] = self.xp.fft.ifftn
1401 elif direction == 'object':
1402 self.fft_cache[key] = None
1403 else:
1404 if direction == 'object':
1405 from mpi4py_fft import PFFT
1407 _fft = PFFT(
1408 comm=self.comm,
1409 shape=shape,
1410 axes=sorted(axes),
1411 dtype='D',
1412 collapse=False,
1413 backend=self.fft_backend,
1414 comm_backend=self.fft_comm_backend,
1415 padding=padding,
1416 )
1417 else:
1418 _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)
1420 if direction == 'forward':
1421 self.fft_cache[key] = _fft.forward
1422 elif direction == 'backward':
1423 self.fft_cache[key] = _fft.backward
1424 elif direction == 'object':
1425 self.fft_cache[key] = _fft
1427 return self.fft_cache[key]
1429 def setup_fft(self, real_spectral_coefficients=False):
1430 """
1431 This function must be called after all axes have been setup in order to prepare the local shapes of the data.
1432 This must also be called before setting up any BCs.
1434 Args:
1435 real_spectral_coefficients (bool): Allow only real coefficients in spectral space
1436 """
1437 if len(self.components) == 0:
1438 self.add_component('u')
1440 self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
1441 self.local_slice = [slice(0, me.N) for me in self.axes]
1443 axes = tuple(i for i in range(len(self.axes)))
1444 self.fft_obj = self.get_fft(axes=axes, direction='object')
1445 if self.fft_obj is not None:
1446 self.local_slice = self.fft_obj.local_slice(False)
1448 self.init = (
1449 np.empty(shape=self.global_shape)[
1450 (
1451 ...,
1452 *self.local_slice,
1453 )
1454 ].shape,
1455 self.comm,
1456 np.dtype('float'),
1457 )
1458 self.init_forward = (
1459 np.empty(shape=self.global_shape)[
1460 (
1461 ...,
1462 *self.local_slice,
1463 )
1464 ].shape,
1465 self.comm,
1466 np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
1467 )
1469 self.BC_mat = self.get_empty_operator_matrix()
1470 self.BC_rhs_mask = self.xp.zeros(
1471 shape=self.init[0],
1472 dtype=bool,
1473 )
1475 def _transform_fft(self, u, axes, **kwargs):
1476 """
1477 FFT along `axes`
1479 Args:
1480 u: The solution
1481 axes (tuple): Axes you want to transform over
1483 Returns:
1484 transformed solution
1485 """
1486 # TODO: clean up and try putting more of this in the 1D bases
1487 fft = self.get_fft(axes, 'forward', **kwargs)
1488 return fft(u, axes=axes)
1490 def _transform_dct(self, u, axes, padding=None, **kwargs):
1491 '''
1492 DCT along `axes`.
1493 This will only return real values!
1494 When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
1495 wavenumbers in FFTs.
1497 Args:
1498 u: The solution
1499 axes (tuple): Axes you want to transform over
1501 Returns:
1502 transformed solution
1503 '''
1504 # TODO: clean up and try putting more of this in the 1D bases
1505 if self.debug:
1506 assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'
1508 if len(axes) > 1:
1509 v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
1510 else:
1511 v = u.copy().astype(complex)
1512 axis = axes[0]
1513 base = self.axes[axis]
1515 shuffle = [slice(0, s, 1) for s in u.shape]
1516 shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
1517 v = v[(*shuffle,)]
1519 if padding is not None:
1520 shape = list(v.shape)
1521 if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
1522 shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
1523 elif self.comm:
1524 send_buf = np.array(v.shape[0])
1525 recv_buf = np.array(v.shape[0])
1526 self.comm.Allreduce(send_buf, recv_buf)
1527 shape[0] = int(recv_buf)
1528 fft = self.get_fft(axes, 'forward', shape=shape)
1529 else:
1530 fft = self.get_fft(axes, 'forward', **kwargs)
1532 v = fft(v, axes=axes)
1534 expansion = [np.newaxis for _ in u.shape]
1535 expansion[axis] = slice(0, v.shape[axis], 1)
1537 if padding is not None:
1538 shift = base.get_fft_shift(True, v.shape[axis])
1540 if padding[axis] != 1:
1541 N = int(np.ceil(v.shape[axis] / padding[axis]))
1542 _expansion = [slice(0, n) for n in v.shape]
1543 _expansion[axis] = slice(0, N, 1)
1544 v = v[(*_expansion,)]
1545 else:
1546 shift = base.fft_utils['fwd']['shift']
1548 v *= shift[(*expansion,)]
1550 return v.real
1552 def transform_single_component(self, u, axes=None, padding=None):
1553 """
1554 Transform a single component of the solution
1556 Args:
1557 u data to transform:
1558 axes (tuple): Axes over which to transform
1559 padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming
1561 Returns:
1562 Transformed data
1563 """
1564 # TODO: clean up and try putting more of this in the 1D bases
1565 trfs = {
1566 ChebychevHelper: self._transform_dct,
1567 UltrasphericalHelper: self._transform_dct,
1568 FFTHelper: self._transform_fft,
1569 }
1571 axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
1572 padding = (
1573 [
1574 1,
1575 ]
1576 * self.ndim
1577 if padding is None
1578 else padding
1579 ) # You know, sometimes I feel very strongly about Black still. This atrocious formatting is readable by Sauron only.
1581 result = u.copy().astype(complex)
1582 alignment = self.ndim - 1
1584 axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
1585 bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
1586 axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
1587 shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]
1589 fft = self.get_fft(axes=axes, padding=padding, direction='object')
1590 if fft is not None:
1591 shape = list(fft.global_shape(False))
1593 for trf in range(len(axes_collapsed)):
1594 _axes = axes_collapsed[trf]
1595 base = bases[trf]
1597 if len(_axes) == 0:
1598 continue
1600 for _ax in _axes:
1601 shape[_ax] = self.global_shape[1 + self.ndim + _ax]
1603 fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)
1605 _in = self.get_aligned(
1606 result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
1607 )
1609 alignment = self.ndim + _axes[-1]
1611 _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)
1613 if self.comm is not None:
1614 _out *= np.prod([self.axes[i].N for i in _axes])
1616 axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
1617 alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
1618 result = self.get_aligned(
1619 _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
1620 )
1622 return result
1624 def transform(self, u, axes=None, padding=None):
1625 """
1626 Transform all components from physical space to spectral space
1628 Args:
1629 u data to transform:
1630 axes (tuple): Axes over which to transform
1631 padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming
1633 Returns:
1634 Transformed data
1635 """
1636 axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
1637 padding = (
1638 [
1639 1,
1640 ]
1641 * self.ndim
1642 if padding is None
1643 else padding
1644 )
1646 result = [
1647 None,
1648 ] * self.ncomponents
1649 for comp in self.components:
1650 i = self.index(comp)
1652 result[i] = self.transform_single_component(u[i], axes=axes, padding=padding)
1654 return self.xp.stack(result)
1656 def _transform_ifft(self, u, axes, **kwargs):
1657 # TODO: clean up and try putting more of this in the 1D bases
1658 ifft = self.get_fft(axes, 'backward', **kwargs)
1659 return ifft(u, axes=axes)
1661 def _transform_idct(self, u, axes, padding=None, **kwargs):
1662 '''
1663 This will only ever return real values!
1664 '''
1665 # TODO: clean up and try putting more of this in the 1D bases
1666 if self.debug:
1667 assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'
1669 v = u.copy().astype(complex)
1671 if len(axes) > 1:
1672 v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
1673 else:
1674 axis = axes[0]
1675 base = self.axes[axis]
1677 if padding is not None:
1678 if padding[axis] != 1:
1679 N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
1680 _pad = [[0, 0] for _ in v.shape]
1681 _pad[axis] = [0, N_pad - base.N]
1682 v = self.xp.pad(v, _pad, 'constant')
1684 shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
1685 else:
1686 shift = base.fft_utils['bck']['shift']
1687 else:
1688 shift = base.fft_utils['bck']['shift']
1690 expansion = [np.newaxis for _ in u.shape]
1691 expansion[axis] = slice(0, v.shape[axis], 1)
1693 v *= shift[(*expansion,)]
1695 if padding is not None:
1696 if padding[axis] != 1:
1697 shape = list(v.shape)
1698 if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
1699 shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
1700 elif self.comm:
1701 send_buf = np.array(v.shape[0])
1702 recv_buf = np.array(v.shape[0])
1703 self.comm.Allreduce(send_buf, recv_buf)
1704 shape[0] = int(recv_buf)
1705 ifft = self.get_fft(axes, 'backward', shape=shape)
1706 else:
1707 ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
1708 else:
1709 ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
1710 v = ifft(v, axes=axes)
1712 shuffle = [slice(0, s, 1) for s in v.shape]
1713 shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
1714 v = v[(*shuffle,)]
1716 return v.real
1718 def itransform_single_component(self, u, axes=None, padding=None):
1719 """
1720 Inverse transform over single component of the solution
1722 Args:
1723 u data to transform:
1724 axes (tuple): Axes over which to transform
1725 padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming
1727 Returns:
1728 Transformed data
1729 """
1730 # TODO: clean up and try putting more of this in the 1D bases
1731 trfs = {
1732 FFTHelper: self._transform_ifft,
1733 ChebychevHelper: self._transform_idct,
1734 UltrasphericalHelper: self._transform_idct,
1735 }
1737 axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
1738 padding = (
1739 [
1740 1,
1741 ]
1742 * self.ndim
1743 if padding is None
1744 else padding
1745 )
1747 result = u.copy().astype(complex)
1748 alignment = self.ndim - 1
1750 axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
1751 bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
1752 axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
1753 shape = list(self.global_shape[1:])
1755 for trf in range(len(axes_collapsed)):
1756 _axes = axes_collapsed[trf]
1757 base = bases[trf]
1759 if len(_axes) == 0:
1760 continue
1762 fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)
1764 _in = self.get_aligned(
1765 result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
1766 )
1767 if self.comm is not None:
1768 _in /= np.prod([self.axes[i].N for i in _axes])
1770 alignment = self.ndim + _axes[0]
1772 _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)
1774 for _ax in _axes:
1775 if fft:
1776 shape[_ax] = fft._input_shape[_ax]
1777 else:
1778 shape[_ax] = _out.shape[_ax]
1780 axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
1781 alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
1782 result = self.get_aligned(
1783 _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
1784 )
1786 return result
1788 def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
1789 """
1790 Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
1791 object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
1792 shape especially, if applicable.
1794 Args:
1795 u: The solution
1796 axis_in (int): Current alignment
1797 axis_out (int): New alignment
1798 fft (mpi4py_fft.PFFT), optional: parallel FFT object
1799 forward (bool): Whether the input is in spectral space or not
1801 Returns:
1802 solution aligned on `axis_in`
1803 """
1804 if self.comm is None or axis_in == axis_out:
1805 return u.copy()
1806 if self.comm.size == 1:
1807 return u.copy()
1809 global_fft = self.get_fft(**kwargs)
1810 axisA = [me.axisA for me in global_fft.transfer]
1811 axisB = [me.axisB for me in global_fft.transfer]
1813 current_axis = axis_in
1815 if axis_in in axisA and axis_out in axisB:
1816 while current_axis != axis_out:
1817 transfer = global_fft.transfer[axisA.index(current_axis)]
1819 arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
1820 arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
1821 arrayA[:] = u[:]
1823 transfer.forward(arrayA=arrayA, arrayB=arrayB)
1825 current_axis = transfer.axisB
1826 u = arrayB
1828 return u
1829 elif axis_in in axisB and axis_out in axisA:
1830 while current_axis != axis_out:
1831 transfer = global_fft.transfer[axisB.index(current_axis)]
1833 arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
1834 arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
1835 arrayB[:] = u[:]
1837 transfer.backward(arrayA=arrayA, arrayB=arrayB)
1839 current_axis = transfer.axisA
1840 u = arrayA
1842 return u
1843 else: # go the potentially slower route of not reusing transfer classes
1844 from mpi4py_fft import newDistArray
1846 fft = self.get_fft(**kwargs) if fft is None else fft
1848 _in = newDistArray(fft, forward).redistribute(axis_in)
1849 _in[...] = u
1851 return _in.redistribute(axis_out)
1853 def itransform(self, u, axes=None, padding=None):
1854 """
1855 Inverse transform over all components of the solution
1857 Args:
1858 u data to transform:
1859 axes (tuple): Axes over which to transform
1860 padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming
1862 Returns:
1863 Transformed data
1864 """
1865 axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
1866 padding = (
1867 [
1868 1,
1869 ]
1870 * self.ndim
1871 if padding is None
1872 else padding
1873 )
1875 result = [
1876 None,
1877 ] * self.ncomponents
1878 for comp in self.components:
1879 i = self.index(comp)
1881 result[i] = self.itransform_single_component(u[i], axes=axes, padding=padding)
1883 return self.xp.stack(result)
1885 def get_local_slice_of_1D_matrix(self, M, axis):
1886 """
1887 Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
1888 which you can sort out via the `SpectralHelper.local_slice` attribute. When constructing a 1D matrix, you can
1889 use this method to get the part corresponding to the modes carried by this rank.
1891 Args:
1892 M (sparse matrix): Global 1D matrix you want to get the local version of
1893 axis (int): Direction in which you want the local version. You will get the global matrix in other directions. This means slab decomposition only.
1895 Returns:
1896 sparse local matrix
1897 """
1898 return M.tocsc()[self.local_slice[axis], self.local_slice[axis]]
1900 def expand_matrix_ND(self, matrix, aligned):
1901 sp = self.sparse_lib
1902 axes = np.delete(np.arange(self.ndim), aligned)
1903 ndim = len(axes) + 1
1905 if ndim == 1:
1906 return matrix
1907 elif ndim == 2:
1908 axis = axes[0]
1909 I1D = sp.eye(self.axes[axis].N)
1911 mats = [None] * ndim
1912 mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
1913 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
1915 return sp.kron(*mats)
1917 else:
1918 raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')
1920 def get_filter_matrix(self, axis, **kwargs):
1921 """
1922 Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
1923 admissible.
1925 Returns:
1926 sparse bandpass matrix
1927 """
1928 if self.ndim == 1:
1929 return self.axes[0].get_filter_matrix(**kwargs)
1931 mats = [base.get_Id() for base in self.axes]
1932 mats[axis] = self.axes[axis].get_filter_matrix(**kwargs)
1933 return self.sparse_lib.kron(*mats)
1935 def get_differentiation_matrix(self, axes, **kwargs):
1936 """
1937 Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation.
1939 Args:
1940 axes (tuple): Axes along which to differentiate.
1942 Returns:
1943 sparse differentiation matrix
1944 """
1945 D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0])
1946 for axis in axes[1:]:
1947 _D = self.axes[axis].get_differentiation_matrix(**kwargs)
1948 D = D @ self.expand_matrix_ND(_D, axis)
1950 return D
1952 def get_integration_matrix(self, axes):
1953 """
1954 Get integration matrix to integrate along specified axis.
1956 Args:
1957 axes (tuple): Axes along which to integrate over.
1959 Returns:
1960 sparse integration matrix
1961 """
1962 S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0])
1963 for axis in axes[1:]:
1964 _S = self.axes[axis].get_integration_matrix()
1965 S = S @ self.expand_matrix_ND(_S, axis)
1967 return S
1969 def get_Id(self):
1970 """
1971 Get identity matrix
1973 Returns:
1974 sparse identity matrix
1975 """
1976 I = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
1977 for axis in range(1, self.ndim):
1978 _I = self.axes[axis].get_Id()
1979 I = I @ self.expand_matrix_ND(_I, axis)
1980 return I
1982 def get_Dirichlet_recombination_matrix(self, axis=-1):
1983 """
1984 Get Dirichlet recombination matrix along axis. Not that it only makes sense in directions discretized with variations of Chebychev bases.
1986 Args:
1987 axis (int): Axis you discretized with Chebychev
1989 Returns:
1990 sparse matrix
1991 """
1992 C1D = self.axes[axis].get_Dirichlet_recombination_matrix()
1993 return self.expand_matrix_ND(C1D, axis)
1995 def get_basis_change_matrix(self, axes=None, **kwargs):
1996 """
1997 Some spectral bases do a change between bases while differentiating. This method returns matrices that changes the basis to whatever you want.
1998 Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.
2000 Args:
2001 axes (tuple): Axes along which to change basis.
2003 Returns:
2004 sparse basis change matrix
2005 """
2006 axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
2008 C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0])
2009 for axis in axes[1:]:
2010 _C = self.axes[axis].get_basis_change_matrix(**kwargs)
2011 C = C @ self.expand_matrix_ND(_C, axis)
2013 return C