Coverage for pySDC/helpers/spectral_helper.py: 92%
778 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-20 10:09 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain, or None if one of the boundaries is not given
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np

    def __init__(self, N, x0=None, x1=None, useGPU=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
        """
        self.N = N
        self.x0 = x0
        self.x1 = x1
        # The boundaries default to None. Only compute the length if both are given instead of failing on `None - None`.
        self.L = None if x0 is None or x1 is None else x1 - x0
        self.useGPU = useGPU

        if useGPU:
            self.setup_GPU()

    @classmethod
    def setup_GPU(cls):
        """
        Switch the array, sparse, sparse linear algebra and FFT modules to their CuPy counterparts.
        Note that this replaces class attributes, i.e. it affects all instances of the class it is called on.
        """
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretizations change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
                        individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Modes whose absolute wavenumber is below `kmin` or at least `kmax` are set to zero.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter. Defaults to the largest absolute wavenumber.

        Returns:
            sparse matrix (a scipy matrix, also when running on GPU)
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        # The filter is assembled with scipy on the host, so both the identity and the mask have to be moved there.
        # Indexing a scipy matrix with a device array is not possible.
        if self.useGPU:
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
    """

    def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
                                             use the FFT from the library to compute the DCT. Case-insensitive.
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented
        """
        assert x0 == -1
        assert x1 == 1
        super().__init__(*args, x0=x0, x1=x1, **kwargs)
        self.transform_type = transform_type

        # `transform` and `itransform` compare the transform type case-insensitively, so do the same here. Otherwise
        # e.g. 'FFT' would select the FFT code path without the required utilities being set up.
        if self.transform_type.lower() == 'fft':
            self.get_fft_utils()

        self.cache = {}
        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5))

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices of the default size `self.N` are cached, so feel free to call the method as often as
        you like. Matrices of other sizes are computed on every call.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = N if N else self.N

        # Only matrices of the default size are cached. Otherwise a matrix computed for a custom `N` would be returned
        # by later calls that do not specify `N`.
        use_cache = N == self.N
        if use_cache and name in self.cache:
            return self.cache[name]

        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)) / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # Backward conversions are obtained by inverting the corresponding forward conversion
            try:
                fwd = get_forward_conv(name[::-1])
                import scipy.sparse as scipy_sparse

                if self.sparse_lib == scipy_sparse:
                    mat = self.sparse_lib.linalg.inv(fwd.tocsc())
                else:
                    # Otherwise, the matrix is copied to the host with `.get()` and inverted with scipy
                    mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))
            except NotImplementedError:
                raise NotImplementedError from E

        if use_cache:
            self.cache[name] = mat
        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = self.xp.arange(self.N)
        S = S.tocsc()
        # The first row holds the integration constant; its odd entries are set for the lower bound x=0
        S[0, 1::2] = (
            (n / (2 * (n + 1)))[1::2]
            * (-1) ** (self.xp.arange(self.N // 2))
            / (np.append([1], self.xp.arange(self.N // 2 - 1) + 1))
        )
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            sparse matrix: Differentiation matrix in CSC format (dense in content)
        '''
        xp = self.xp
        j = xp.arange(self.N)  # column index
        k = j[:, None]  # row index
        # D[k, j] = 2 j if k < j and j - k is odd, zero otherwise. Built vectorized rather than with a double loop.
        D = xp.where(k < j, 2.0 * j * ((j - k) % 2), 0.0)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p))

    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def get_fft_shuffle(self, forward, N):
        """
        In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
        doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
        in the complex plane. This function returns a mask to do the shuffling.

        Args:
            forward (bool): Whether you want the shuffle for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Use as mask
        """
        xp = self.xp
        if forward:
            return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
        else:
            mask = xp.zeros(N, dtype=int)
            mask[: N - N % 2 : 2] = xp.arange(N // 2)
            mask[1::2] = N - xp.arange(N // 2) - 1
            mask[-1] = N // 2
            return mask

    def get_fft_shift(self, forward, N):
        """
        As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.

        Args:
            forward (bool): Whether you want the rotation for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Rotation
        """
        xp = self.xp
        # Use the requested size `N` consistently, rather than the resolution of this helper
        k = xp.arange(N)
        norm = self.get_norm(N)
        if forward:
            return 2 * xp.exp(-1j * np.pi * k / (2 * N)) * norm
        else:
            shift = xp.exp(1j * np.pi * k / (2 * N))
            shift[0] = 0.5
            return shift / norm

    def get_fft_utils(self):
        """
        Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
        them cached.
        """
        self.fft_utils = {
            'fwd': {},
            'bck': {},
        }

        # forwards transform
        self.fft_utils['fwd']['shuffle'] = self.get_fft_shuffle(True, self.N)
        self.fft_utils['fwd']['shift'] = self.get_fft_shift(True, self.N)

        # backwards transform
        self.fft_utils['bck']['shuffle'] = self.get_fft_shuffle(False, self.N)
        self.fft_utils['bck']['shift'] = self.get_fft_shift(False, self.N)

        return self.fft_utils

    def transform(self, u, axis=-1, **kwargs):
        """
        1D DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in spectral space
        """
        if self.transform_type.lower() == 'dct':
            return self.fft_lib.dct(u, axis=axis, **kwargs) * self.norm
        elif self.transform_type.lower() == 'fft':
            result = u.copy()

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['fwd']['shuffle']

            v = u[(*shuffle,)]

            V = self.fft_lib.fft(v, axis=axis, **kwargs)

            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)

            V *= self.fft_utils['fwd']['shift'][(*expansion,)]

            # NOTE(review): only the real part is written to the result, so this path presumably assumes real input
            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def itransform(self, u, axis=-1):
        """
        1D inverse DCT along axis.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in physical space
        """
        assert self.norm.shape[0] == u.shape[axis]

        # Compare case-insensitively, consistent with `transform`
        if self.transform_type.lower() == 'dct':
            return self.fft_lib.idct(u / self.norm, axis=axis)
        elif self.transform_type.lower() == 'fft':
            result = u.copy()

            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)

            v = self.fft_lib.ifft(u * self.fft_utils['bck']['shift'][(*expansion,)], axis=axis)

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['bck']['shuffle']
            V = v[(*shuffle,)]

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        # integral of T_i over [-1, 1] is ((-1)^i + 1) / (1 - i^2) for i >= 2; n[1:-1] holds i = 2, ..., N-1
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        # 2^(l-1) (l-1)! times diag(l, l+1, ..., N-1) placed on the l-th superdiagonal
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l)

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # built with scipy in lil format to allow scaling the first column in place
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        if p_in == p_out:
            # No conversion needed. Skip inverting an identity matrix.
            return self.sparse_lib.csc_matrix(self.sparse_lib.eye(self.N))

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_basis_change_matrix(
            p_out=1, p_in=0
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp
        n = u_hat.shape[axis]

        # Select modes 1, ..., n-1 along `axis` and keep all entries along the other axes
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # Signs (-1)^(i-1) for mode i, because T_i(-1) = (-1)^i. They are broadcast along `axis` only.
        expansion = [np.newaxis] * u_hat.ndim
        expansion[axis] = slice(None)
        signs = (-1) ** xp.arange(n - 1)

        return xp.sum(u_hat[tuple(slices)] * signs[tuple(expansion)], axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems. Transforms are FFTs along one axis and differentiation is diagonal in
    spectral space.
    """

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points that include the left boundary but not the right one. The right boundary is
        excluded because, by periodicity, it coincides with the left boundary.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Get the wavenumbers in the ordering of `fftfreq`: zero first, then the positive and then the negative
        wavenumbers. Be careful, this ordering is very unintuitive.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        if p < 0:
            raise ValueError(f'The order of the derivative must be non-negative, got {p=}')

        k = self.get_wavenumbers()

        # The p-th power of a diagonal matrix is the diagonal matrix of the p-th powers of its entries. Computing it
        # elementwise avoids `matrix_power`, which is not implemented in CuPy at the time of writing.
        D = self.sparse_lib.diags((1j * k) ** p)

        if self.useGPU:
            return self.sparse_lib.csc_matrix(D)
        return D

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        if p < 0:
            raise ValueError(f'The order of the integral must be non-negative, got {p=}')

        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        # k[0] is zero; replace it to avoid division by zero
        k[0] = 1j * self.L
        # Elementwise power of the diagonal, because `matrix_power` is not available in CuPy
        return self.sparse_lib.diags((1 / (1j * k)) ** p)

    def transform(self, u, axis=-1, **kwargs):
        """
        1D FFT along axis. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.fft(u, axis=axis, **kwargs)

    def itransform(self, u, axis=-1):
        """
        Inverse 1D FFT.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.ifft(u, axis=axis)

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        # `argmin` returns the first occurrence of the minimum, like searching for it with `where`
        return self.get_wavenumbers().argmin()

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
808class SpectralHelper:
809 """
810 This class has three functions:
811 - Easily assemble matrices containing multiple equations
812 - Direct product of 1D bases to solve problems in more dimensions
813 - Distribute the FFTs to facilitate concurrency.
815 Attributes:
816 comm (mpi4py.Intracomm): MPI communicator
817 debug (bool): Perform additional checks at extra computational cost
818 useGPU (bool): Whether to use GPUs
819 axes (list): List of 1D bases
820 components (list): List of strings of the names of components in the equations
821 full_BCs (list): List of Dictionaries containing all information about the boundary conditions
822 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
823 BCs (sparse matrix): Matrix containing only the BCs
824 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
825 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
826 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs
827 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
828 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
829 global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
830 local_slice (slice): Local slice of the solution as in `mpi4py-fft`
831 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
832 init (tuple): This is the same `init` that is used throughout the problem classes
833 init_forward (tuple): This is the equivalent of `init` in spectral space
834 """
836 xp = np
837 fft_lib = scipy.fft
838 sparse_lib = scipy.sparse
839 linalg = scipy.sparse.linalg
840 dtype = mesh
841 fft_backend = 'fftw'
842 fft_comm_backend = 'MPI'
844 @classmethod
845 def setup_GPU(cls):
846 """switch to GPU modules"""
847 import cupy as cp
848 import cupyx.scipy.sparse as sparse_lib
849 import cupyx.scipy.sparse.linalg as linalg
850 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
852 cls.xp = cp
853 cls.sparse_lib = sparse_lib
854 cls.linalg = linalg
856 cls.fft_backend = 'cupy'
857 cls.fft_comm_backend = 'NCCL'
859 cls.dtype = cupy_mesh
861 def __init__(self, comm=None, useGPU=False, debug=False):
862 """
863 Constructor
865 Args:
866 comm (mpi4py.Intracomm): MPI communicator
867 useGPU (bool): Whether to use GPUs
868 debug (bool): Perform additional checks at extra computational cost
869 """
870 self.comm = comm
871 self.debug = debug
872 self.useGPU = useGPU
874 if useGPU:
875 self.setup_GPU()
877 self.axes = []
878 self.components = []
880 self.full_BCs = []
881 self.BC_mat = None
882 self.BCs = None
884 self.fft_cache = {}
885 self.fft_dealias_shape_cache = {}
887 @property
888 def u_init(self):
889 """
890 Get empty data container in physical space
891 """
892 return self.dtype(self.init)
894 @property
895 def u_init_forward(self):
896 """
897 Get empty data container in spectral space
898 """
899 return self.dtype(self.init_forward)
901 @property
902 def shape(self):
903 """
904 Get shape of individual solution component
905 """
906 return self.init[0][1:]
908 @property
909 def ndim(self):
910 return len(self.axes)
912 @property
913 def ncomponents(self):
914 return len(self.components)
916 @property
917 def V(self):
918 """
919 Get domain volume
920 """
921 return np.prod([me.L for me in self.axes])
923 def add_axis(self, base, *args, **kwargs):
924 """
925 Add an axis to the domain by deciding on suitable 1D base.
926 Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
927 bases for possible arguments.
929 Args:
930 base (str): 1D spectral method
931 """
932 kwargs['useGPU'] = self.useGPU
934 if base.lower() in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
935 kwargs['transform_type'] = kwargs.get('transform_type', 'fft')
936 self.axes.append(ChebychevHelper(*args, **kwargs))
937 elif base.lower() in ['fft', 'fourier', 'ffthelper']:
938 self.axes.append(FFTHelper(*args, **kwargs))
939 elif base.lower() in ['ultraspherical', 'gegenbauer']:
940 self.axes.append(UltrasphericalHelper(*args, **kwargs))
941 else:
942 raise NotImplementedError(f'{base=!r} is not implemented!')
943 self.axes[-1].xp = self.xp
944 self.axes[-1].sparse_lib = self.sparse_lib
946 def add_component(self, name):
947 """
948 Add solution component(s).
950 Args:
951 name (str or list of strings): Name(s) of component(s)
952 """
953 if type(name) in [list, tuple]:
954 for me in name:
955 self.add_component(me)
956 elif type(name) in [str]:
957 if name in self.components:
958 raise Exception(f'{name=!r} is already added to this problem!')
959 self.components.append(name)
960 else:
961 raise NotImplementedError
963 def index(self, name):
964 """
965 Get the index of component `name`.
967 Args:
968 name (str or list of strings): Name(s) of component(s)
970 Returns:
971 int: Index of the component
972 """
973 if type(name) in [str, int]:
974 return self.components.index(name)
975 elif type(name) in [list, tuple]:
976 return (self.index(me) for me in name)
977 else:
978 raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
980 def get_empty_operator_matrix(self):
981 """
982 Return a matrix of operators to be filled with the connections between the solution components.
984 Returns:
985 list containing sparse zeros
986 """
987 S = len(self.components)
988 O = self.get_Id() * 0
989 return [[O for _ in range(S)] for _ in range(S)]
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC only in the first position in the other direction instead of in all positions

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # The 1D matrix is always assembled with scipy on the CPU, so a row computed on the GPU is copied to the host.
    BC = sp.eye(base.N).tolil() * 0
    row = base.get_BC(kind=kind, **kwargs)
    BC[line, :] = row.get() if self.useGPU else row

    ndim = len(self.axes)
    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # only the first position in the other direction receives the BC
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        if self.useGPU:
            Id = Id.get()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        return self.sparse_lib.csc_matrix(sp.kron(*mats))
    else:
        raise NotImplementedError(f'Boundary conditions not implemented for {ndim} dimensions!')
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Note that this only subtracts the BC from the operator and clears the corresponding entry of the rhs mask; the
    entry recorded in `full_BCs` by `add_BC` is left untouched.

    Args:
        component (str): Name of the component the BC acts on
        equation (str): Name of the equation the BC was put in
        axis (int): Axis you want to remove the BC from
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line the BC was put in
        scalar (bool): Whether the BC was put only in the first position in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
    else:
        slices = (
            [self.index(equation)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis)]
            + [line]
            + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
        )

    N = self.axes[axis].N
    global_line = (N + line) % N
    if global_line in self.xp.arange(N)[self.local_slice[axis]]:
        # Convert the global line to the index within the data carried by this rank, as done in `add_BC`.
        # Normalizing first keeps negative lines correct when this rank's slice does not start at zero.
        slices[axis + 1] = global_line - self.local_slice[axis].start
        self.BC_rhs_mask[(*slices,)] = False
def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC only in the first position in the other direction instead of in all positions
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # with a communicator, only rank 0 marks the scalar BC in the rhs mask
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        slices = (
            [self.index(equation)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis)]
            + [line]
            + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
        )
        N = self.axes[axis].N
        global_line = (N + line) % N
        if global_line in self.xp.arange(N)[self.local_slice[axis]]:
            # Convert the global line to the index within the data carried by this rank. Normalizing first keeps
            # negative lines (e.g. -1) correct when this rank's slice does not start at zero.
            slices[axis + 1] = global_line - self.local_slice[axis].start
            self.BC_rhs_mask[(*slices,)] = True
def setup_BCs(self):
    """
    Assemble the boundary condition operator from the nested list of BC matrices.

    Boundary bordering needs all other entries in rows that carry a BC to vanish. To that end, a sparse diagonal
    matrix is built with zeros on the diagonal in exactly those rows and ones everywhere else. Finally, the BCs are
    precomputed in spectral space so they can be added to right hand sides cheaply later on.
    """
    sparse = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    flat_mask = self.BC_rhs_mask.flatten()
    self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[flat_mask]

    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sparse.diags(keep_rows)

    physical_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(physical_BCs)
def check_BCs(self, u):
    """
    Assert that `u` fulfills all non-scalar boundary conditions that were added with `add_BC`.

    For every axis that carries BCs, `u` is transformed along that axis only and the BC row of the 1D base is applied
    to the respective component. Only one- and two-dimensional problems are supported.

    Args:
        u: The solution you want to check
    """
    assert self.ndim < 3
    not_forwarded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        relevant = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not relevant:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for bc in relevant:
            bc_kwargs = {key: value for key, value in bc.items() if key not in not_forwarded}
            bc_row = self.axes[axis].get_BC(**bc_kwargs)
            component_hat = u_hat[self.index(bc['component'])]

            if axis == 0:
                get = bc_row @ component_hat
            elif axis == 1:
                get = component_hat @ bc_row
            want = bc['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {bc["component"]} in equation {bc["equation"]}, line {bc["line"]}! Got {get}, wanted {want}'
def put_BCs_in_matrix(self, A):
    """
    Replace the rows of `A` that hold boundary conditions by the BC rows.

    Args:
        A: Operator without boundary conditions

    Returns:
        Operator with the BC rows zeroed out and the BCs added in their place
    """
    A_without_BC_rows = self.BC_line_zero_matrix @ A
    return A_without_BC_rows + self.BCs
def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.

    Note that `rhs_hat` is modified in place: the entries in rows holding BCs are set to zero before the BCs are
    added. The mask is built on the first call, so BCs added afterwards are not taken into account.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate a mask of the entries in the rhs in spectral space that are set to zero so they can be replaced
        # by the boundary conditions. The mask is then cached.
        self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if axis != bc['axis']:
                    continue

                N = self.axes[axis].N
                global_line = (N + bc['line']) % N
                if global_line in self.xp.arange(N)[self.local_slice[axis]]:
                    # Index relative to the data carried by this rank. Normalizing first keeps negative lines
                    # correct when this rank's slice does not start at zero.
                    _slice = (
                        [self.index(bc['equation'])]
                        + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                        + [global_line - self.local_slice[axis].start]
                        + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                    )
                    self._rhs_hat_zero_mask[(*_slice,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat
def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        for bc in self.full_BCs:
            if axis != bc['axis']:
                continue

            N = self.axes[axis].N
            global_line = (N + bc['line']) % N
            if global_line in self.xp.arange(N)[self.local_slice[axis]]:
                # Index relative to the data carried by this rank. Normalizing first keeps negative lines correct
                # when this rank's slice does not start at zero.
                _slice = (
                    [self.index(bc['equation'])]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                    + [global_line - self.local_slice[axis].start]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                )
                _rhs_hat[(*_slice,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs
def add_equation_lhs(self, A, equation, relations):
    """
    Insert the implicitly treated (left hand side) part of one equation into a nested list of sparse matrices. The
    nested list is converted to a single operator later on.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled; modified in place
        equation (str): Name of the equation (i.e. the row) the relations go into
        relations (dict): Maps component names to the operators acting on them in this equation
    """
    if not relations:
        return

    row = A[self.index(equation)]
    for component, operator in relations.items():
        row[self.index(component)] = operator
def convert_operator_matrix_to_operator(self, M):
    """
    Merge a nested list of sparse matrices into one sparse operator. See `SpectralHelper.add_equation_lhs` for an
    example of how to build the nested list.

    Args:
        M (list of lists of sparse matrices): One block per pair of components

    Returns:
        The only block if there is a single component, otherwise the block matrix in CSC format
    """
    if len(self.components) != 1:
        return self.sparse_lib.bmat(M, format='csc')
    return M[0][0]
def get_wavenumbers(self):
    """
    Get the grid in spectral space, restricted to the modes carried by this rank.

    Returns:
        list of arrays as returned by `meshgrid`, with the 1D wavenumbers passed in reverse axis order
    """
    local_wavenumbers = []
    for i, base in enumerate(self.axes):
        local_wavenumbers.append(base.get_wavenumbers()[self.local_slice[i]])
    local_wavenumbers.reverse()
    return self.xp.meshgrid(*local_wavenumbers)
def get_grid(self):
    """
    Get the grid in physical space, restricted to the points carried by this rank.

    Returns:
        list of arrays as returned by `meshgrid`, with the 1D grids passed in reverse axis order
    """
    local_points = []
    for i, base in enumerate(self.axes):
        local_points.append(base.get_1dgrid()[self.local_slice[i]])
    local_points.reverse()
    return self.xp.meshgrid(*local_points)
def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get a (cached) transform. Without MPI, the serial `fftn`/`ifftn` of the array module are returned, and `None`
    when asking for the "object". When using MPI, we use `PFFT` objects generated by mpi4py-fft.

    Args:
        axes (tuple): Axes you want to transform over. Defaults to all axes. Any iterable of ints is accepted and
            converted to a tuple, such that it can be part of the cache key.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to
            get the PFFT object
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform

    Raises:
        ValueError: If `direction` is not one of "forward", "backward" or "object"
    """
    if direction not in ('forward', 'backward', 'object'):
        raise ValueError(f'Unknown direction {direction!r}! Use "forward", "backward" or "object".')

    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else tuple(axes)
    shape = self.global_shape[1:] if shape is None else shape
    padding = [1] * self.ndim if padding is None else padding
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache:
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            else:
                self.fft_cache[key] = None
        elif direction == 'object':
            from mpi4py_fft import PFFT

            self.fft_cache[key] = PFFT(
                comm=self.comm,
                shape=shape,
                axes=sorted(axes),
                dtype='D',
                collapse=False,
                backend=self.fft_backend,
                comm_backend=self.fft_comm_backend,
                padding=padding,
            )
        else:
            # the forward and backward functions are methods of the (cached) PFFT object
            _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)
            self.fft_cache[key] = _fft.forward if direction == 'forward' else _fft.backward

    return self.fft_cache[key]
def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    If no component has been added, a single component "u" is added.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
    self.local_slice = [slice(0, me.N) for me in self.axes]

    axes = tuple(range(len(self.axes)))
    self.fft_obj = self.get_fft(axes=axes, direction='object')
    if self.fft_obj is not None:
        self.local_slice = self.fft_obj.local_slice(False)

    # Compute the shape of the data on this rank directly from the slices. Slicing an array of the global shape, as
    # done before, needlessly allocated the entire global data on every rank.
    local_shape = (self.global_shape[0],) + tuple(
        len(range(*_slice.indices(N))) for _slice, N in zip(self.local_slice, self.global_shape[1:])
    )

    self.init = (
        local_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.xp.zeros(
        shape=self.init[0],
        dtype=bool,
    )
1428 def _transform_fft(self, u, axes, **kwargs):
1429 """
1430 FFT along `axes`
1432 Args:
1433 u: The solution
1434 axes (tuple): Axes you want to transform over
1436 Returns:
1437 transformed solution
1438 """
1439 # TODO: clean up and try putting more of this in the 1D bases
1440 fft = self.get_fft(axes, 'forward', **kwargs)
1441 return fft(u, axes=axes)
def _transform_dct(self, u, axes, padding=None, **kwargs):
    '''
    DCT along `axes`, computed by an FFT of data reordered according to the 1D base.
    This will only return real values!
    When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
    wavenumbers in FFTs.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over. Multiple axes are handled one after another by recursion.
        padding (list): Padding factor per axis. If given and not 1 along the transformed axis, only the first
            `ceil(n / padding[axis])` of the `n` computed modes along that axis are kept.
        kwargs: Forwarded to `get_fft` when no padding is given

    Returns:
        transformed solution (real part only)
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    if len(axes) > 1:
        # Transform along all but the first axis, then along the first one.
        # NOTE(review): `padding` is a named argument and hence not part of `kwargs`, so it is not forwarded to the
        # recursive calls -- multi-axis DCTs are computed without padding. Confirm this is intended.
        v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
    else:
        v = u.copy().astype(complex)
        axis = axes[0]
        base = self.axes[axis]

        # Reorder the data along `axis` as prescribed by the 1D base before applying the FFT
        shuffle = [slice(0, s, 1) for s in u.shape]
        shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
        v = v[(*shuffle,)]

        if padding is not None:
            # Get an FFT matching the (padded) shape of the data. The size along the first axis is taken from
            # `fft_dealias_shape_cache` if available, otherwise with MPI it is reduced over all ranks via `Allreduce`
            # (default reduction operation). Note that `padding` itself is not passed to `get_fft` here.
            shape = list(v.shape)
            if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
                shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
            elif self.comm:
                send_buf = np.array(v.shape[0])
                recv_buf = np.array(v.shape[0])
                self.comm.Allreduce(send_buf, recv_buf)
                shape[0] = int(recv_buf)
            fft = self.get_fft(axes, 'forward', shape=shape)
        else:
            fft = self.get_fft(axes, 'forward', **kwargs)

        v = fft(v, axes=axes)

        # Index that broadcasts a 1D array along `axis` against the data
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        if padding is not None:
            shift = base.get_fft_shift(True, v.shape[axis])

            if padding[axis] != 1:
                # Discard the upper modes that stem from the padded resolution
                N = int(np.ceil(v.shape[axis] / padding[axis]))
                _expansion = [slice(0, n) for n in v.shape]
                _expansion[axis] = slice(0, N, 1)
                v = v[(*_expansion,)]
        else:
            shift = base.fft_utils['fwd']['shift']

        # Multiply by the shift provided by the 1D base (presumably the phase factors and normalization turning the
        # FFT of the reordered data into DCT coefficients).
        # NOTE(review): after truncation, `expansion` still spans the untruncated length, so this relies on `shift`
        # having no more entries than the truncated data -- confirm in the 1D base.
        v *= shift[(*expansion,)]

    return v.real
def transform_single_component(self, u, axes=None, padding=None):
    """
    Transform a single component of the solution from physical to spectral space.

    Axes that share the same type of 1D base are grouped and transformed with a single call to the respective
    transform. With MPI, the data is realigned between the groups so that the transformed axes are available on
    every rank.

    Args:
        u: Data to transform
        axes (tuple): Axes over which to transform. Defaults to all axes.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # Map from 1D base class to the transform used for it. Bases are matched by exact type below, so subclasses of
    # these helpers are not picked up.
    trfs = {
        ChebychevHelper: self._transform_dct,
        UltrasphericalHelper: self._transform_dct,
        FFTHelper: self._transform_fft,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )  # You know, sometimes I feel very strongly about Black still. This atrocious formatting is readable by Sauron only.

    result = u.copy().astype(complex)
    # non-negative index of the axis along which the data is currently aligned; starts out as the last axis
    alignment = self.ndim - 1

    # Group the requested axes by the type of their 1D base and drop bases without any requested axis
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]

    # With MPI, take the global shape from the PFFT object for all axes and padding
    fft = self.get_fft(axes=axes, padding=padding, direction='object')
    if fft is not None:
        shape = list(fft.global_shape(False))

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        # the axes transformed in this step use the unpadded global resolution
        for _ax in _axes:
            shape[_ax] = self.global_shape[1 + self.ndim + _ax]

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
        )

        alignment = self.ndim + _axes[-1]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # NOTE(review): the result is scaled by the number of modes only when using MPI, presumably to match the
        # normalization of the serial transforms -- confirm against mpi4py-fft's conventions.
        if self.comm is not None:
            _out *= np.prod([self.axes[i].N for i in _axes])

        # Realign for the next group of axes, or keep the current alignment after the last group
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
        )

    return result
def transform(self, u, axes=None, padding=None):
    """
    Transform every component of the solution from physical to spectral space.

    Args:
        u: Data to transform, with the component index as first dimension
        axes (tuple): Axes over which to transform. Defaults to all axes, given as negative indices.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data, stacked along the component dimension
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for comp in self.components:
        idx = self.index(comp)
        transformed[idx] = self.transform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)
1609 def _transform_ifft(self, u, axes, **kwargs):
1610 # TODO: clean up and try putting more of this in the 1D bases
1611 ifft = self.get_fft(axes, 'backward', **kwargs)
1612 return ifft(u, axes=axes)
def _transform_idct(self, u, axes, padding=None, **kwargs):
    '''
    Inverse DCT along `axes`, computed by an inverse FFT followed by undoing the reordering of the 1D base.
    This will only ever return real values!

    Args:
        u: Data in spectral space
        axes (tuple): Axes you want to transform over. Multiple axes are handled one after another by recursion.
        padding (list): Padding factor per axis. If given and not 1 along the transformed axis, the coefficients are
            zero-padded to `ceil(n * padding[axis])` modes before transforming.
        kwargs: Forwarded to `get_fft` unless padding along the transformed axis is requested

    Returns:
        transformed data (real part only)
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    v = u.copy().astype(complex)

    if len(axes) > 1:
        # Transform along all but the first axis, then along the first one.
        # NOTE(review): neither `padding` nor `kwargs` are forwarded to the recursive calls -- confirm intended.
        v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
    else:
        axis = axes[0]
        base = self.axes[axis]

        if padding is not None:
            if padding[axis] != 1:
                # Zero-pad the coefficients along `axis` to `N_pad` modes.
                # NOTE(review): the pad width uses `base.N` while `N_pad` uses `v.shape[axis]`; this assumes the two
                # agree -- confirm.
                N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
                _pad = [[0, 0] for _ in v.shape]
                _pad[axis] = [0, N_pad - base.N]
                v = self.xp.pad(v, _pad, 'constant')

                # shift computed for the padded length
                shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
            else:
                shift = base.fft_utils['bck']['shift']
        else:
            shift = base.fft_utils['bck']['shift']

        # Index that broadcasts the 1D shift along `axis` against the data
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        v *= shift[(*expansion,)]

        if padding is not None:
            if padding[axis] != 1:
                # Get an inverse FFT matching the padded shape. The size along the first axis is taken from
                # `fft_dealias_shape_cache` if available, otherwise with MPI it is reduced over all ranks via
                # `Allreduce` (default reduction operation).
                shape = list(v.shape)
                if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
                    shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
                elif self.comm:
                    send_buf = np.array(v.shape[0])
                    recv_buf = np.array(v.shape[0])
                    self.comm.Allreduce(send_buf, recv_buf)
                    shape[0] = int(recv_buf)
                ifft = self.get_fft(axes, 'backward', shape=shape)
            else:
                ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        else:
            ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        v = ifft(v, axes=axes)

        # Undo the reordering of the data along `axis` as prescribed by the 1D base
        shuffle = [slice(0, s, 1) for s in v.shape]
        shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
        v = v[(*shuffle,)]

    return v.real
def itransform_single_component(self, u, axes=None, padding=None):
    """
    Inverse transform over single component of the solution, from spectral to physical space.

    Axes that share the same type of 1D base are grouped and transformed with a single call to the respective
    inverse transform. With MPI, the data is realigned between the groups.

    Args:
        u: Data to transform
        axes (tuple): Axes over which to transform. Defaults to all axes.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # Map from 1D base class to the inverse transform used for it. Bases are matched by exact type below.
    trfs = {
        FFTHelper: self._transform_ifft,
        ChebychevHelper: self._transform_idct,
        UltrasphericalHelper: self._transform_idct,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    # non-negative index of the axis along which the data is currently aligned; starts out as the last axis
    alignment = self.ndim - 1

    # Group the requested axes by the type of their 1D base and drop bases without any requested axis
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = list(self.global_shape[1:])

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
        )
        # NOTE(review): counterpart of the scaling in `transform_single_component`; only applied with MPI
        if self.comm is not None:
            _in /= np.prod([self.axes[i].N for i in _axes])

        alignment = self.ndim + _axes[0]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # After transforming, the transformed axes may have a (padded) physical resolution: take it from the PFFT
        # object if there is one, otherwise from the output itself.
        for _ax in _axes:
            if fft:
                shape[_ax] = fft._input_shape[_ax]
            else:
                shape[_ax] = _out.shape[_ax]

        # Realign for the next group of axes, or keep the current alignment after the last group
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
        )

    return result
def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
    """
    Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
    object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
    shape especially, if applicable.

    Without a communicator, with a single rank, or if `axis_in == axis_out`, a copy of `u` is returned.

    Args:
        u: The solution
        axis_in (int): Current alignment
        axis_out (int): New alignment
        fft (mpi4py_fft.PFFT), optional: parallel FFT object
        forward (bool): Whether the input is in spectral space or not

    Returns:
        solution aligned on `axis_out`
    """
    if self.comm is None or axis_in == axis_out:
        return u.copy()
    if self.comm.size == 1:
        return u.copy()

    # Reuse the transfer objects of the PFFT object matching `kwargs`
    global_fft = self.get_fft(**kwargs)
    axisA = [me.axisA for me in global_fft.transfer]
    axisB = [me.axisB for me in global_fft.transfer]

    current_axis = axis_in

    if axis_in in axisA and axis_out in axisB:
        # step forward through the transfers until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisA.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayA[:] = u[:]

            transfer.forward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisB
            u = arrayB

        return u
    elif axis_in in axisB and axis_out in axisA:
        # step backward through the transfers until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisB.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayB[:] = u[:]

            transfer.backward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisA
            u = arrayA

        return u
    else:  # go the potentially slower route of not reusing transfer classes
        from mpi4py_fft import newDistArray

        fft = self.get_fft(**kwargs) if fft is None else fft

        _in = newDistArray(fft, forward).redistribute(axis_in)
        _in[...] = u

        return _in.redistribute(axis_out)
def itransform(self, u, axes=None, padding=None):
    """
    Transform every component of the solution from spectral back to physical space.

    Args:
        u: Data to transform, with the component index as first dimension
        axes (tuple): Axes over which to transform. Defaults to all axes, given as negative indices.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming

    Returns:
        Transformed data, stacked along the component dimension
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for comp in self.components:
        idx = self.index(comp)
        transformed[idx] = self.itransform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)
def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Restrict a global 1D matrix to the modes carried by this rank.

    With distributed FFTs, each rank holds only the subset of modes given by `SpectralHelper.local_slice`. This
    picks the matching rows and columns, i.e. it assumes a slab decomposition where the matrix is block-local.

    Args:
        M (sparse matrix): Global 1D matrix
        axis (int): Direction the matrix acts in

    Returns:
        sparse local matrix (CSC)
    """
    local = self.local_slice[axis]
    csc = M.tocsc()
    return csc[local, local]
def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible.

    In two dimensions, the matrix is restricted to the modes carried by this rank, like the other operators.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)
    elif self.ndim == 2:
        mats = [self.get_local_slice_of_1D_matrix(base.get_Id(), i) for i, base in enumerate(self.axes)]
        mats[axis] = self.get_local_slice_of_1D_matrix(self.axes[axis].get_filter_matrix(**kwargs), axis)
        return self.sparse_lib.kron(*mats)
    else:
        raise NotImplementedError(f'Filter matrix not implemented for {self.ndim} dimension!')
def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation.

    In two dimensions, the 1D matrices for all entries of `axes` are composed in order. Repeated axes are allowed,
    e.g. `axes=(0, 0)` gives the second derivative along axis 0.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix

    Raises:
        ValueError: If no axis is given in two dimensions
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = self.ndim

    if ndim == 1:
        D = self.axes[0].get_differentiation_matrix(**kwargs)
    elif ndim == 2:
        if len(axes) == 0:
            raise ValueError('Need at least one axis to compute a differentiation matrix!')

        for i, axis in enumerate(axes):
            axis2 = (axis + 1) % ndim
            D1D = self.axes[axis].get_differentiation_matrix(**kwargs)

            if len(axes) > 1:
                I1D = sp.eye(self.axes[axis2].N)
            else:
                I1D = self.axes[axis2].get_Id()

            mats = [None] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

            # Compose by position rather than by comparing the axis value with `axes[0]`, which would restart the
            # product for repeated axes.
            if i == 0:
                D = sp.kron(*mats)
            else:
                D = D @ sp.kron(*mats)
    else:
        raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!')

    return D
def get_integration_matrix(self, axes):
    """
    Get integration matrix to integrate along specified axis.

    In two dimensions, the 1D matrices for all entries of `axes` are composed in order. Repeated axes are allowed,
    e.g. `axes=(0, 0)` integrates twice along axis 0.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix

    Raises:
        ValueError: If no axis is given in two dimensions
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        S = self.axes[0].get_integration_matrix()
    elif ndim == 2:
        if len(axes) == 0:
            raise ValueError('Need at least one axis to compute an integration matrix!')

        for i, axis in enumerate(axes):
            axis2 = (axis + 1) % ndim
            S1D = self.axes[axis].get_integration_matrix()

            if len(axes) > 1:
                I1D = sp.eye(self.axes[axis2].N)
            else:
                I1D = self.axes[axis2].get_Id()

            mats = [None] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

            # Compose by position rather than by comparing the axis value with `axes[0]`, which would restart the
            # product for repeated axes.
            if i == 0:
                S = sp.kron(*mats)
            else:
                S = S @ sp.kron(*mats)
    else:
        raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')

    return S
def get_Id(self):
    """
    Build the identity operator for the data carried by this rank.

    In one dimension, the identity of the 1D base is returned. In two dimensions, the identities of both bases are
    combined via Kronecker products, restricted to the local modes.

    Returns:
        sparse identity matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sparse = self.sparse_lib
    n_dims = self.ndim
    identity = sparse.eye(np.prod(self.init[0][1:]), dtype=complex)

    if n_dims == 1:
        return self.axes[0].get_Id()
    if n_dims != 2:
        raise NotImplementedError(f'Identity matrix not implemented for {n_dims} dimension!')

    for axis in range(n_dims):
        other = (axis + 1) % n_dims
        factors = [None] * n_dims
        factors[axis] = self.get_local_slice_of_1D_matrix(self.axes[axis].get_Id(), axis)
        factors[other] = self.get_local_slice_of_1D_matrix(sparse.eye(self.axes[other].N), other)
        identity = identity @ sparse.kron(*factors)

    return identity
def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Build the Dirichlet recombination matrix along `axis`. This is only meaningful for directions discretized with
    a variant of the Chebychev bases.

    Args:
        axis (int): Axis discretized with a Chebychev base

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sparse = self.sparse_lib
    n_dims = len(self.axes)

    if n_dims == 1:
        return self.axes[0].get_Dirichlet_recombination_matrix()
    if n_dims != 2:
        raise NotImplementedError(f'Basis change matrix not implemented for {n_dims} dimension!')

    other = (axis + 1) % n_dims
    factors = [None] * n_dims
    factors[axis] = self.get_local_slice_of_1D_matrix(self.axes[axis].get_Dirichlet_recombination_matrix(), axis)
    factors[other] = self.get_local_slice_of_1D_matrix(self.axes[other].get_Id(), other)
    return sparse.kron(*factors)
def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases do a change between bases while differentiating. This method returns matrices that changes the basis to whatever you want.
    Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.

    In two dimensions, the 1D matrices for all entries of `axes` are composed in order, including repeated axes.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes.

    Returns:
        sparse basis change matrix

    Raises:
        ValueError: If no axis is given in two dimensions
        NotImplementedError: For more than two dimensions
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes

    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        C = self.axes[0].get_basis_change_matrix(**kwargs)
    elif ndim == 2:
        if len(axes) == 0:
            raise ValueError('Need at least one axis to compute a basis change matrix!')

        for i, axis in enumerate(axes):
            axis2 = (axis + 1) % ndim
            C1D = self.axes[axis].get_basis_change_matrix(**kwargs)

            if len(axes) > 1:
                I1D = sp.eye(self.axes[axis2].N)
            else:
                I1D = self.axes[axis2].get_Id()

            mats = [None] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

            # Compose by position rather than by comparing the axis value with `axes[0]`, which would restart the
            # product for repeated axes.
            if i == 0:
                C = sp.kron(*mats)
            else:
                C = C @ sp.kron(*mats)
    else:
        raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')

    return C