Coverage for pySDC/helpers/spectral_helper.py: 92%
778 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 13:12 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 13:12 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain, or None if either boundary is not given
        useGPU (bool): Whether to use GPUs
    """

    # CPU modules; `setup_GPU` swaps these for their CuPy equivalents
    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np

    def __init__(self, N, x0=None, x1=None, useGPU=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
        """
        self.N = N
        self.x0 = x0
        self.x1 = x1
        # The length is only defined if both boundaries are known. The defaults are None, in which case computing
        # `x1 - x0` would raise a TypeError.
        self.L = None if x0 is None or x1 is None else x1 - x0
        self.useGPU = useGPU

        if useGPU:
            self.setup_GPU()

    @classmethod
    def setup_GPU(cls):
        """
        Switch to GPU modules.

        Note that this sets class attributes, so it affects every instance of the class it is called on.
        """
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretizations change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
                        individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Columns belonging to modes with absolute wavenumber `k` satisfying `k >= kmax` or
        `k < kmin` are zeroed.

        Args:
            kmin (int): Lower limit of the bandpass filter (inclusive)
            kmax (int): Upper limit of the bandpass filter (exclusive). Defaults to the largest absolute wavenumber.

        Returns:
            sparse matrix in CSC format
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        # The masking is done with scipy on the CPU. CuPy refuses implicit conversion to NumPy, so the mask has to be
        # moved to the host explicitly, just like the identity.
        if self.useGPU:
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
    """

    def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
                                             use the FFT from the library to compute the DCT. Case-insensitive.
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)
        self.transform_type = transform_type

        # `transform` compares case-insensitively, so do the same here to make sure the FFT utilities exist
        if self.transform_type.lower() == 'fft':
            self.get_fft_utils()

        self.cache = {}
        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices of the default size are cached. So feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = N if N else self.N

        # The cache is keyed by name only, so only matrices of the default size may be stored in or read from it.
        use_cache = N == self.N
        if use_cache and name in self.cache:
            return self.cache[name]

        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)) / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # try to obtain the backward conversion by inverting the forward one
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError from E

            import scipy.sparse as scipy_sparse
            import scipy.sparse.linalg as scipy_linalg

            if self.sparse_lib == scipy_sparse:
                mat = scipy_linalg.inv(fwd.tocsc())
            else:
                # invert on the CPU and move the result back to the GPU
                mat = self.sparse_lib.csc_matrix(scipy_linalg.inv(fwd.tocsc().get()))

        if use_cache:
            self.cache[name] = mat
        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix
        """
        xp = self.xp
        S = self.sparse_lib.diags(1 / (xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = xp.arange(self.N)
        if lbnd == 0:
            S = S.tocsc()
            S[0, 1::2] = (
                (n / (2 * (xp.arange(self.N) + 1)))[1::2]
                * (-1) ** (xp.arange(self.N // 2))
                / (xp.append([1], xp.arange(self.N // 2 - 1) + 1))
            ) * self.lin_trf_fac
        else:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        xp = self.xp
        row = xp.arange(self.N)[:, None]
        col = xp.arange(self.N)[None, :]
        # entries 2j in column j for all rows k < j with odd j - k; zero everywhere else
        D = xp.where((row < col) & ((col - row) % 2 == 1), 2.0 * col, 0.0)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def get_fft_shuffle(self, forward, N):
        """
        In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
        doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
        in the complex plane. This function returns a mask to do the shuffling.

        Args:
            forward (bool): Whether you want the shuffle for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Use as mask
        """
        xp = self.xp
        if forward:
            return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
        else:
            mask = xp.zeros(N, dtype=int)
            mask[: N - N % 2 : 2] = xp.arange(N // 2)
            mask[1::2] = N - xp.arange(N // 2) - 1
            mask[-1] = N // 2
            return mask

    def get_fft_shift(self, forward, N):
        """
        As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.

        Args:
            forward (bool): Whether you want the rotation for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Rotation
        """
        xp = self.xp
        # use the requested size rather than `self.N` so that the rotation matches the grid
        k = xp.arange(N)
        norm = self.get_norm(N)
        if forward:
            return 2 * xp.exp(-1j * np.pi * k / (2 * N)) * norm
        else:
            shift = xp.exp(1j * np.pi * k / (2 * N))
            shift[0] = 0.5
            return shift / norm

    def get_fft_utils(self):
        """
        Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
        them cached.
        """
        self.fft_utils = {
            'fwd': {},
            'bck': {},
        }

        # forwards transform
        self.fft_utils['fwd']['shuffle'] = self.get_fft_shuffle(True, self.N)
        self.fft_utils['fwd']['shift'] = self.get_fft_shift(True, self.N)

        # backwards transform
        self.fft_utils['bck']['shuffle'] = self.get_fft_shuffle(False, self.N)
        self.fft_utils['bck']['shift'] = self.get_fft_shift(False, self.N)

        return self.fft_utils

    @staticmethod
    def _expand_along_axis(values, ndim, axis, stop=None):
        """Index the 1D array `values` such that it broadcasts along `axis` of an `ndim`-dimensional array."""
        expansion = [np.newaxis for _ in range(ndim)]
        expansion[axis] = slice(0, stop, 1)
        return values[(*expansion,)]

    def transform(self, u, axis=-1, **kwargs):
        """
        1D DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in spectral space
        """
        transform_type = self.transform_type.lower()
        if transform_type == 'dct':
            # the normalization has to be applied along the transformed axis, not necessarily the last one
            norm = self._expand_along_axis(self.norm, u.ndim, axis)
            return self.fft_lib.dct(u, axis=axis, **kwargs) * norm
        elif transform_type == 'fft':
            result = u.copy()

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['fwd']['shuffle']

            v = u[(*shuffle,)]

            V = self.fft_lib.fft(v, axis=axis, **kwargs)

            V *= self._expand_along_axis(self.fft_utils['fwd']['shift'], u.ndim, axis, stop=u.shape[axis])

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def itransform(self, u, axis=-1, **kwargs):
        """
        1D inverse DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in physical space
        """
        assert self.norm.shape[0] == u.shape[axis]

        transform_type = self.transform_type.lower()
        if transform_type == 'dct':
            # undo the normalization along the transformed axis, not necessarily the last one
            norm = self._expand_along_axis(self.norm, u.ndim, axis)
            return self.fft_lib.idct(u / norm, axis=axis, **kwargs)
        elif transform_type == 'fft':
            result = u.copy()

            shift = self._expand_along_axis(self.fft_utils['bck']['shift'], u.ndim, axis, stop=u.shape[axis])
            v = self.fft_lib.ifft(u * shift, axis=axis, **kwargs)

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['bck']['shuffle']
            V = v[(*shuffle,)]

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_k(0) = cos(k pi / 2) = 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        if p == 0:
            # The zeroth derivative is the identity. The general formula below would give a zero matrix instead,
            # because scipy's `factorial(-1)` evaluates to 0.
            return sp.eye(N)
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # built with scipy on the CPU and converted at the end
            # NOTE(review): presumably because the column assignment is not supported in CuPy - confirm
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min(p_in, p_out), max(p_in, p_out)):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        elif p_out == p_in:
            # Nothing to convert. Return the identity in CSC format without running a sparse inversion.
            return self.sparse_lib.csc_matrix(mat_fwd)

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp
        n = u_hat.shape[axis]

        # take modes 1, ..., n-1 along `axis` and keep all other axes untouched
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # T_k(-1) = (-1)**k, so the constant is sum_{k>=1} (-1)**(k-1) u_hat_k. Shape the signs to broadcast along `axis`.
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = n - 1
        signs = ((-1) ** xp.arange(n - 1)).reshape(sign_shape)

        return xp.sum(u_hat[(*slices,)] * signs, axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic domains. Transforms are plain FFTs, and differentiation and integration matrices are
    diagonal in spectral space.
    """

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        Get equally spaced points that include the left boundary but not the right one. The right boundary is left
        out because, in a periodic domain, it is the same point as the left boundary.

        Returns:
            self.xp.array: 1D grid
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Get the wavenumbers in the ordering of `fftfreq`: zero first, then the positive ones, then the negative ones.
        Be careful that this ordering is very unintuitive.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix

        Raises:
            ValueError: If `p` is negative
        """
        if p < 0:
            raise ValueError('exponent must be >= 0')
        k = self.get_wavenumbers()

        # The matrix is diagonal, so its p-th power is the diagonal matrix of p-th powers. This avoids sparse matrix
        # powers, which CuPy does not provide.
        return self.sparse_lib.diags((1j * k) ** p).tocsc()

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix

        Raises:
            ValueError: If `p` is negative
        """
        if p < 0:
            raise ValueError('exponent must be >= 0')
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L

        # diagonal matrix: take the elementwise power instead of a sparse matrix power, which CuPy does not provide
        return self.sparse_lib.diags((1 / (1j * k)) ** p).tocsc()

    def transform(self, u, axis=-1, **kwargs):
        """
        1D FFT along axis. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.fft(u, axis=axis, **kwargs)

    def itransform(self, u, axis=-1, **kwargs):
        """
        Inverse 1D FFT. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.ifft(u, axis=axis, **kwargs)

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        Nyquist_mode = k.min()
        return self.xp.where(k == Nyquist_mode)[0][0]

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
class SpectralHelper:
    """
    This class has three functions:
     - Easily assemble matrices containing multiple equations
     - Direct product of 1D bases to solve problems in more dimensions
     - Distribute the FFTs to facilitate concurrency.

    Attributes:
        comm (mpi4py.Intracomm): MPI communicator
        debug (bool): Perform additional checks at extra computational cost
        useGPU (bool): Whether to use GPUs
        axes (list): List of 1D bases
        components (list): List of strings of the names of components in the equations
        full_BCs (list): List of Dictionaries containing all information about the boundary conditions
        BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
        BCs (sparse matrix): Matrix containing only the BCs
        fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
        BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
        BC_zero_index (self.xp.ndarray): Indices of rows in the matrix that are replaced by BCs
        BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
        rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
        global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
        local_slice (slice): Local slice of the solution as in `mpi4py-fft`
        fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
        init (tuple): This is the same `init` that is used throughout the problem classes
        init_forward (tuple): This is the equivalent of `init` in spectral space
    """

    # CPU defaults. `setup_GPU` overrides xp, sparse_lib, linalg, dtype and the two FFT backend names; fft_lib is
    # left untouched there.
    xp = np
    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    dtype = mesh
    fft_backend = 'fftw'
    fft_comm_backend = 'MPI'
847 @classmethod
848 def setup_GPU(cls):
849 """switch to GPU modules"""
850 import cupy as cp
851 import cupyx.scipy.sparse as sparse_lib
852 import cupyx.scipy.sparse.linalg as linalg
853 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
855 cls.xp = cp
856 cls.sparse_lib = sparse_lib
857 cls.linalg = linalg
859 cls.fft_backend = 'cupy'
860 cls.fft_comm_backend = 'NCCL'
862 cls.dtype = cupy_mesh
864 def __init__(self, comm=None, useGPU=False, debug=False):
865 """
866 Constructor
868 Args:
869 comm (mpi4py.Intracomm): MPI communicator
870 useGPU (bool): Whether to use GPUs
871 debug (bool): Perform additional checks at extra computational cost
872 """
873 self.comm = comm
874 self.debug = debug
875 self.useGPU = useGPU
877 if useGPU:
878 self.setup_GPU()
880 self.axes = []
881 self.components = []
883 self.full_BCs = []
884 self.BC_mat = None
885 self.BCs = None
887 self.fft_cache = {}
888 self.fft_dealias_shape_cache = {}
890 @property
891 def u_init(self):
892 """
893 Get empty data container in physical space
894 """
895 return self.dtype(self.init)
897 @property
898 def u_init_forward(self):
899 """
900 Get empty data container in spectral space
901 """
902 return self.dtype(self.init_forward)
904 @property
905 def shape(self):
906 """
907 Get shape of individual solution component
908 """
909 return self.init[0][1:]
911 @property
912 def ndim(self):
913 return len(self.axes)
915 @property
916 def ncomponents(self):
917 return len(self.components)
919 @property
920 def V(self):
921 """
922 Get domain volume
923 """
924 return np.prod([me.L for me in self.axes])
926 def add_axis(self, base, *args, **kwargs):
927 """
928 Add an axis to the domain by deciding on suitable 1D base.
929 Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
930 bases for possible arguments.
932 Args:
933 base (str): 1D spectral method
934 """
935 kwargs['useGPU'] = self.useGPU
937 if base.lower() in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
938 kwargs['transform_type'] = kwargs.get('transform_type', 'fft')
939 self.axes.append(ChebychevHelper(*args, **kwargs))
940 elif base.lower() in ['fft', 'fourier', 'ffthelper']:
941 self.axes.append(FFTHelper(*args, **kwargs))
942 elif base.lower() in ['ultraspherical', 'gegenbauer']:
943 self.axes.append(UltrasphericalHelper(*args, **kwargs))
944 else:
945 raise NotImplementedError(f'{base=!r} is not implemented!')
946 self.axes[-1].xp = self.xp
947 self.axes[-1].sparse_lib = self.sparse_lib
949 def add_component(self, name):
950 """
951 Add solution component(s).
953 Args:
954 name (str or list of strings): Name(s) of component(s)
955 """
956 if type(name) in [list, tuple]:
957 for me in name:
958 self.add_component(me)
959 elif type(name) in [str]:
960 if name in self.components:
961 raise Exception(f'{name=!r} is already added to this problem!')
962 self.components.append(name)
963 else:
964 raise NotImplementedError
966 def index(self, name):
967 """
968 Get the index of component `name`.
970 Args:
971 name (str or list of strings): Name(s) of component(s)
973 Returns:
974 int: Index of the component
975 """
976 if type(name) in [str, int]:
977 return self.components.index(name)
978 elif type(name) in [list, tuple]:
979 return (self.index(me) for me in name)
980 else:
981 raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
983 def get_empty_operator_matrix(self):
984 """
985 Return a matrix of operators to be filled with the connections between the solution components.
987 Returns:
988 list containing sparse zeros
989 """
990 S = len(self.components)
991 O = self.get_Id() * 0
992 return [[O for _ in range(S)] for _ in range(S)]
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # 1D matrix that is zero everywhere except for the BC row in `line`
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # only the first entry in the other direction is kept
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        if self.useGPU:
            Id = Id.get()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        return self.sparse_lib.csc_matrix(sp.kron(*mats))
    else:
        # Previously this fell through and silently returned `None`
        raise NotImplementedError(f'Boundary conditions not implemented for {ndim} dimension!')
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.
    The bookkeeping of the rhs mask mirrors `add_BC`.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # Like in `add_BC`, scalar BCs only live on rank 0 when running distributed
        if self.comm:
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = False
        else:
            self.BC_rhs_mask[(*slices,)] = False
    else:
        N = self.axes[axis].N
        global_line = (N + line) % N
        if global_line in self.xp.arange(N)[self.local_slice[axis]]:
            # Convert the global line to the index in the data carried by this rank
            local_line = global_line - (self.local_slice[axis].start or 0)
            slices = (
                [self.index(equation)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [local_line]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            self.BC_rhs_mask[(*slices,)] = False
def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in. Negative values count from the end of the (global) axis.
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # scalar BCs are only registered on rank 0 when running distributed
        if self.comm:
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        N = self.axes[axis].N
        global_line = (N + line) % N
        if global_line in self.xp.arange(N)[self.local_slice[axis]]:
            # Convert the non-negative global line to the index in the data carried by this rank. Shifting the raw,
            # possibly negative, `line` would point to the wrong entry on ranks whose slice does not start at 0.
            local_line = global_line - (self.local_slice[axis].start or 0)
            slices = (
                [self.index(equation)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [local_line]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            self.BC_rhs_mask[(*slices,)] = True
def setup_BCs(self):
    """
    Finalize the boundary conditions after all calls to `add_BC`/`remove_BC`:
    assemble the BC operator from the nested list `self.BC_mat`, record the flat indices of all rows hosting a BC,
    build a diagonal matrix that blanks out exactly those rows (needed for boundary bordering) and precompute the
    BC values in spectral space so they can be added to the right hand side directly.
    """
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    n_dof = np.prod(self.init[0])
    row_has_bc = self.BC_rhs_mask.flatten()
    self.BC_zero_index = self.xp.arange(n_dof)[row_has_bc]

    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = self.sparse_lib.diags(keep_rows)

    # BC values in spectral space, ready to be added to the rhs
    self.rhs_BCs_hat = self.transform(self.put_BCs_in_rhs(self.u_init))
def check_BCs(self, u):
    """
    Verify that `u` fulfills every non-scalar boundary condition registered via `add_BC`, raising an
    `AssertionError` for the first violated one. Only up to two dimensions are supported.

    Args:
        u: The solution you want to check
    """
    assert self.ndim < 3
    bookkeeping_keys = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        relevant = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not relevant:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for BC in relevant:
            # everything that is not bookkeeping is forwarded to the 1D base
            bc_kwargs = {key: BC[key] for key in BC if key not in bookkeeping_keys}
            component_hat = u_hat[self.index(BC['component'])]
            bc_row = self.axes[axis].get_BC(**bc_kwargs)

            if axis == 0:
                get = bc_row @ component_hat
            elif axis == 1:
                get = component_hat @ bc_row
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'
def put_BCs_in_matrix(self, A):
    """
    Boundary bordering for an operator: blank out every row of `A` that hosts a boundary condition and put the BC
    rows in their place.

    Args:
        A: sparse operator

    Returns:
        sparse operator with the BC rows in place
    """
    rows_kept = self.BC_line_zero_matrix @ A
    return rows_kept + self.BCs
def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.
    Note that the entries of `rhs_hat` in rows carrying a BC are set to zero in place.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate a mask of the entries in the rhs in spectral space that need to be zeroed such that they can be
        # replaced by the boundary conditions. The mask is cached for subsequent calls.
        self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

        for axis in range(self.ndim):
            N = self.axes[axis].N
            for bc in self.full_BCs:
                if bc['axis'] != axis:
                    continue

                global_line = (N + bc['line']) % N
                if global_line in self.xp.arange(N)[self.local_slice[axis]]:
                    # index of the BC line in the data carried by this rank
                    local_line = global_line - (self.local_slice[axis].start or 0)
                    _slice = (
                        [self.index(bc['equation'])]
                        + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                        + [local_line]
                        + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                    )
                    self._rhs_hat_zero_mask[(*_slice,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat
def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        N = self.axes[axis].N
        for bc in self.full_BCs:
            if bc['axis'] != axis:
                continue

            global_line = (N + bc['line']) % N
            if global_line in self.xp.arange(N)[self.local_slice[axis]]:
                # index of the BC line in the data carried by this rank; shifting the raw, possibly negative, line
                # would address the wrong entry on ranks whose slice does not start at 0
                local_line = global_line - (self.local_slice[axis].start or 0)
                _slice = (
                    [self.index(bc['equation'])]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                    + [local_line]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                )
                _rhs_hat[(*_slice,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs
def add_equation_lhs(self, A, equation, relations):
    """
    Fill the row of the block operator `A` belonging to `equation` with the implicit (left hand side) terms given in
    `relations`. The nested list is converted to a single operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled
        equation (str): The equation of the component you want this in
        relations: (dict): Maps component names to the operators acting on them in this equation
    """
    for component, operator in relations.items():
        row = A[self.index(equation)]
        row[self.index(component)] = operator
def convert_operator_matrix_to_operator(self, M):
    """
    Turn the nested list of per-component sparse blocks into one sparse operator usable as a linear operator.
    With a single component, the lone block is returned unchanged. Otherwise, the blocks are assembled into a CSC
    matrix. See `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator to be converted

    Returns:
        sparse linear operator
    """
    if len(self.components) != 1:
        return self.sparse_lib.bmat(M, format='csc')
    return M[0][0]
def get_wavenumbers(self):
    """
    Get the grid in spectral space: the local part of each axis' wavenumbers combined by `meshgrid`, with the axes
    passed in reverse order.
    """
    local_grids = []
    for i, base in enumerate(self.axes):
        local_grids.append(base.get_wavenumbers()[self.local_slice[i]])
    local_grids.reverse()
    return self.xp.meshgrid(*local_grids)
def get_grid(self):
    """
    Get the grid in physical space: the local part of each axis' 1D grid combined by `meshgrid`, with the axes
    passed in reverse order.
    """
    local_grids = []
    for i, base in enumerate(self.axes):
        local_grids.append(base.get_1dgrid()[self.local_slice[i]])
    local_grids.reverse()
    return self.xp.meshgrid(*local_grids)
def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get (cached) transforms. Without MPI, `fftn`/`ifftn` from `self.xp` are returned. When using MPI, we use `PFFT`
    objects generated by mpi4py-fft.

    Args:
        axes (tuple): Axes you want to transform over. Defaults to all axes. Lists are accepted as well.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to
            get the PFFT object (`None` when not using MPI)
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform

    Raises:
        ValueError: If `direction` is not one of "forward", "backward" or "object"
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))
    elif isinstance(axes, list):
        # lists are not hashable and therefore cannot be part of the cache key
        axes = tuple(axes)
    shape = self.global_shape[1:] if shape is None else shape
    padding = [1] * self.ndim if padding is None else padding
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache:
        if direction not in ('forward', 'backward', 'object'):
            raise ValueError(f'Unknown direction {direction!r}! Use "forward", "backward" or "object".')

        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            else:
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                self.fft_cache[key] = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # the forward/backward functions are methods of the (cached) PFFT object
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)
                self.fft_cache[key] = _fft.forward if direction == 'forward' else _fft.backward

    return self.fft_cache[key]
def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
    self.local_slice = [slice(0, me.N) for me in self.axes]

    axes = tuple(range(len(self.axes)))
    self.fft_obj = self.get_fft(axes=axes, direction='object')
    if self.fft_obj is not None:
        self.local_slice = self.fft_obj.local_slice(False)

    # Shape of the data carried by this rank. It is computed from the slices directly rather than by slicing an
    # allocated array of the global shape, which would needlessly allocate global-size memory on every rank.
    local_shape = self.global_shape[:1] + tuple(
        len(range(n)[s]) for n, s in zip(self.global_shape[1:], self.local_slice)
    )

    self.init = (
        local_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.xp.zeros(
        shape=self.init[0],
        dtype=bool,
    )
def _transform_fft(self, u, axes, **kwargs):
    """
    Forward FFT of `u` over `axes`, using the forward transform obtained from `get_fft` (extra `kwargs` are passed
    on to `get_fft`).

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over

    Returns:
        transformed solution
    """
    # TODO: clean up and try putting more of this in the 1D bases
    return self.get_fft(axes, 'forward', **kwargs)(u, axes=axes)
def _transform_dct(self, u, axes, padding=None, **kwargs):
    '''
    DCT along `axes`.
    This will only return real values!
    When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
    wavenumbers in FFTs.

    The DCT is computed with an FFT. Along the transformed axis, the data is reordered with the base's FFT shuffle,
    transformed with the forward transform from `get_fft`, and multiplied by a shift factor. Several axes are
    handled by recursing over one axis at a time.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over
        padding (list): Padding factor per axis, or `None`. For a factor other than 1, only the first
            ceil(N / padding) modes along that axis are kept after transforming.

    Returns:
        transformed solution
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    if len(axes) > 1:
        # Transform the trailing axes first, then the leading one.
        # NOTE(review): `padding` is a named parameter and hence not part of `kwargs`, so it is not forwarded to
        # the recursive calls and multi-axis transforms take the unpadded path -- confirm this is intended.
        v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
    else:
        v = u.copy().astype(complex)
        axis = axes[0]
        base = self.axes[axis]

        # Reorder the data along `axis` such that the DCT can be computed with an FFT
        shuffle = [slice(0, s, 1) for s in u.shape]
        shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
        v = v[(*shuffle,)]

        if padding is not None:
            shape = list(v.shape)
            if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
                shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
            elif self.comm:
                # Reduce the local extent of the first axis across ranks (presumably summed, mpi4py's default
                # reduction op) to obtain its global extent
                send_buf = np.array(v.shape[0])
                recv_buf = np.array(v.shape[0])
                self.comm.Allreduce(send_buf, recv_buf)
                shape[0] = int(recv_buf)
            fft = self.get_fft(axes, 'forward', shape=shape)
        else:
            fft = self.get_fft(axes, 'forward', **kwargs)

        v = fft(v, axes=axes)

        # Index that broadcasts the 1D shift along `axis` over all other dimensions
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        if padding is not None:
            shift = base.get_fft_shift(True, v.shape[axis])

            if padding[axis] != 1:
                # Keep only the first ceil(N / padding) modes along `axis`.
                # NOTE(review): `expansion` above was built from the length before this truncation, so
                # `shift[(*expansion,)]` below appears to be longer than `v` along `axis` here -- verify this path.
                N = int(np.ceil(v.shape[axis] / padding[axis]))
                _expansion = [slice(0, n) for n in v.shape]
                _expansion[axis] = slice(0, N, 1)
                v = v[(*_expansion,)]
        else:
            shift = base.fft_utils['fwd']['shift']

        v *= shift[(*expansion,)]

    return v.real
def transform_single_component(self, u, axes=None, padding=None):
    """
    Transform a single component of the solution

    The axes are grouped by the type of their 1D base (Chebychev/Ultraspherical axes use `_transform_dct`, Fourier
    axes use `_transform_fft`). The groups are processed one after the other. Before and after each group, the data
    is realigned with `get_aligned`, which only does something when using distributed FFTs.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # The order of this dict determines the order in which the groups of axes are transformed
    trfs = {
        ChebychevHelper: self._transform_dct,
        UltrasphericalHelper: self._transform_dct,
        FFTHelper: self._transform_fft,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )  # You know, sometimes I feel very strongly about Black still. This atrocious formatting is readable by Sauron only.

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # One sorted tuple of axes per base type; empty groups are removed together with their base
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]

    fft = self.get_fft(axes=axes, padding=padding, direction='object')
    if fft is not None:
        # MPI: take the global shape from the PFFT object
        shape = list(fft.global_shape(False))

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        # defensive only: empty groups were already filtered out above
        if len(_axes) == 0:
            continue

        for _ax in _axes:
            # NOTE(review): this index is only valid for negative `_ax` (as in the default `axes`) -- confirm callers
            shape[_ax] = self.global_shape[1 + self.ndim + _ax]

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
        )

        alignment = self.ndim + _axes[-1]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        if self.comm is not None:
            # scale the output of the distributed transforms by the number of modes in the transformed axes
            _out *= np.prod([self.axes[i].N for i in _axes])

        # Align for the next group of axes, or keep the current alignment after the last group
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
        )

    return result
def transform(self, u, axes=None, padding=None):
    """
    Transform every solution component from physical space to spectral space by applying
    `transform_single_component` to each of them and stacking the results.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform. Defaults to all axes.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.transform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)
def _transform_ifft(self, u, axes, **kwargs):
    """
    Inverse FFT of `u` over `axes`, using the backward transform obtained from `get_fft` (extra `kwargs` are passed
    on to `get_fft`).
    """
    # TODO: clean up and try putting more of this in the 1D bases
    return self.get_fft(axes, 'backward', **kwargs)(u, axes=axes)
def _transform_idct(self, u, axes, padding=None, **kwargs):
    '''
    Inverse DCT along `axes`, computed with an inverse FFT. Along the transformed axis, the data is multiplied by a
    shift factor, transformed with the backward transform from `get_fft`, and reordered with the base's FFT
    shuffle. Several axes are handled by recursing over one axis at a time.

    This will only ever return real values!

    Args:
        u: Data to transform
        axes (tuple): Axes you want to transform over
        padding (list): Padding factor per axis, or `None`. For a factor other than 1, the data along that axis is
            zero-padded before transforming.

    Returns:
        transformed data
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    # (overwritten in the multi-axis branch)
    v = u.copy().astype(complex)

    if len(axes) > 1:
        # NOTE(review): neither `padding` nor `kwargs` are forwarded to the recursive calls, so multi-axis inverse
        # transforms take the unpadded path -- confirm this is intended.
        v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
    else:
        axis = axes[0]
        base = self.axes[axis]

        if padding is not None:
            if padding[axis] != 1:
                # Zero-pad along `axis` up to ceil(N * padding) entries.
                # NOTE(review): the pad width uses `base.N`, which presumes v.shape[axis] == base.N -- confirm.
                N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
                _pad = [[0, 0] for _ in v.shape]
                _pad[axis] = [0, N_pad - base.N]
                v = self.xp.pad(v, _pad, 'constant')

                # shift factors for the padded resolution
                shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
            else:
                shift = base.fft_utils['bck']['shift']
        else:
            shift = base.fft_utils['bck']['shift']

        # Index that broadcasts the 1D shift along `axis` over all other dimensions
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        v *= shift[(*expansion,)]

        if padding is not None:
            if padding[axis] != 1:
                shape = list(v.shape)
                if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
                    shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
                elif self.comm:
                    # Reduce the local extent of the first axis across ranks (presumably summed, mpi4py's default
                    # reduction op) to obtain its global extent
                    send_buf = np.array(v.shape[0])
                    recv_buf = np.array(v.shape[0])
                    self.comm.Allreduce(send_buf, recv_buf)
                    shape[0] = int(recv_buf)
                ifft = self.get_fft(axes, 'backward', shape=shape)
            else:
                ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        else:
            ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        v = ifft(v, axes=axes)

        # Reorder the output along `axis` with the base's backward FFT shuffle
        shuffle = [slice(0, s, 1) for s in v.shape]
        shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
        v = v[(*shuffle,)]

    return v.real
def itransform_single_component(self, u, axes=None, padding=None):
    """
    Inverse transform over single component of the solution

    The axes are grouped by the type of their 1D base (Fourier axes use `_transform_ifft`, Chebychev/Ultraspherical
    axes use `_transform_idct`). The groups are processed one after the other, with the Fourier group first.
    Before and after each group, the data is realigned with `get_aligned`, which only does something when using
    distributed FFTs.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # The order of this dict determines the order in which the groups of axes are transformed
    trfs = {
        FFTHelper: self._transform_ifft,
        ChebychevHelper: self._transform_idct,
        UltrasphericalHelper: self._transform_idct,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # One sorted tuple of axes per base type; empty groups are removed together with their base
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = list(self.global_shape[1:])

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        # defensive only: empty groups were already filtered out above
        if len(_axes) == 0:
            continue

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
        )
        if self.comm is not None:
            # undo the scaling by the number of modes applied in the distributed forward transform
            _in /= np.prod([self.axes[i].N for i in _axes])

        alignment = self.ndim + _axes[0]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # Update the shape in the transformed axes, which changes when padding
        for _ax in _axes:
            if fft:
                shape[_ax] = fft._input_shape[_ax]
            else:
                shape[_ax] = _out.shape[_ax]

        # Align for the next group of axes, or keep the current alignment after the last group
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
        )

    return result
def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
    """
    Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
    object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
    shape especially, if applicable.

    Without MPI, on a single rank, or when `axis_in == axis_out`, a copy of `u` is returned. Otherwise, the
    transfer objects of the PFFT object are chained if possible. If not, the data is redistributed via
    `newDistArray`.

    Args:
        u: The solution
        axis_in (int): Current alignment
        axis_out (int): New alignment
        fft (mpi4py_fft.PFFT), optional: parallel FFT object, only used by the `newDistArray` fallback
        forward (bool): Whether the input is in spectral space or not, only used by the `newDistArray` fallback

    Returns:
        solution aligned on `axis_out`
    """
    if self.comm is None or axis_in == axis_out:
        return u.copy()
    if self.comm.size == 1:
        return u.copy()

    global_fft = self.get_fft(**kwargs)
    axisA = [me.axisA for me in global_fft.transfer]
    axisB = [me.axisB for me in global_fft.transfer]

    current_axis = axis_in

    if axis_in in axisA and axis_out in axisB:
        # Chain forward transfers A -> B until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisA.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayA[:] = u[:]

            transfer.forward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisB
            u = arrayB

        return u
    elif axis_in in axisB and axis_out in axisA:
        # Chain backward transfers B -> A until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisB.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayB[:] = u[:]

            transfer.backward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisA
            u = arrayA

        return u
    else:  # go the potentially slower route of not reusing transfer classes
        from mpi4py_fft import newDistArray

        fft = self.get_fft(**kwargs) if fft is None else fft

        _in = newDistArray(fft, forward).redistribute(axis_in)
        _in[...] = u

        return _in.redistribute(axis_out)
def itransform(self, u, axes=None, padding=None):
    """
    Transform every solution component from spectral space back to physical space by applying
    `itransform_single_component` to each of them and stacking the results.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform. Defaults to all axes.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming

    Returns:
        Transformed data
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.itransform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)
def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Restrict a global 1D matrix to the modes this rank carries along `axis` (given by `self.local_slice[axis]`).
    The same slice is applied to rows and columns, and the result is in CSC format. Other directions are not
    affected, so this assumes a slab decomposition.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version

    Returns:
        sparse local matrix
    """
    local = self.local_slice[axis]
    as_csc = M.tocsc()
    return as_csc[local, local]
def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible.

    In more than one dimension, the 1D filter is combined with identities in the other directions via Kronecker
    products. Like the other operators, each 1D factor is restricted to the modes carried by this rank.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    mats = [self.get_local_slice_of_1D_matrix(base.get_Id(), axis=i) for i, base in enumerate(self.axes)]
    mats[axis] = self.get_local_slice_of_1D_matrix(self.axes[axis].get_filter_matrix(**kwargs), axis=axis)

    # `kron` takes exactly two matrices, so chain the products to support any number of dimensions
    F = mats[0]
    for M in mats[1:]:
        F = self.sparse_lib.kron(F, M)
    return F
def get_differentiation_matrix(self, axes, **kwargs):
    """
    Build the operator that differentiates along the given axes. `kwargs` are forwarded to the 1D base
    implementation. In two dimensions, each 1D derivative is paired with an identity in the other direction via a
    Kronecker product, and the factors for several axes are multiplied together.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix
    """
    ndim = self.ndim
    if ndim == 1:
        return self.axes[0].get_differentiation_matrix(**kwargs)
    if ndim != 2:
        raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!')

    sp = self.sparse_lib
    for axis in axes:
        other = (axis + 1) % ndim
        D1D = self.axes[axis].get_differentiation_matrix(**kwargs)
        I1D = sp.eye(self.axes[other].N) if len(axes) > 1 else self.axes[other].get_Id()

        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(I1D, other)
        term = sp.kron(*factors)

        D = term if axis == axes[0] else D @ term
    return D
def get_integration_matrix(self, axes):
    """
    Build the operator that integrates along the given axes. In two dimensions, each 1D integration matrix is paired
    with an identity in the other direction via a Kronecker product, and the factors for several axes are
    multiplied together.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """
    ndim = len(self.axes)
    if ndim == 1:
        return self.axes[0].get_integration_matrix()
    if ndim != 2:
        raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')

    sp = self.sparse_lib
    for axis in axes:
        other = (axis + 1) % ndim
        S1D = self.axes[axis].get_integration_matrix()
        I1D = sp.eye(self.axes[other].N) if len(axes) > 1 else self.axes[other].get_Id()

        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(I1D, other)
        term = sp.kron(*factors)

        S = term if axis == axes[0] else S @ term
    return S
def get_Id(self):
    """
    Build the identity operator for the data carried by this rank. In one dimension, this is the identity of the
    single axis. In two dimensions, it is assembled from Kronecker products of the (locally sliced) 1D identities,
    starting from a complex identity of the local size.

    Returns:
        sparse identity matrix
    """
    sp = self.sparse_lib
    ndim = self.ndim
    n_local = np.prod(self.init[0][1:])
    I = sp.eye(n_local, dtype=complex)

    if ndim == 1:
        return self.axes[0].get_Id()
    if ndim != 2:
        raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!')

    for axis in range(ndim):
        other = (axis + 1) % ndim
        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(self.axes[axis].get_Id(), axis)
        factors[other] = self.get_local_slice_of_1D_matrix(sp.eye(self.axes[other].N), other)
        I = I @ sp.kron(*factors)

    return I
def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get Dirichlet recombination matrix along axis. Note that it only makes sense in directions discretized with
    variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        C = self.axes[0].get_Dirichlet_recombination_matrix()
    elif ndim == 2:
        axis2 = (axis + 1) % ndim
        C1D = self.axes[axis].get_Dirichlet_recombination_matrix()

        I1D = self.axes[axis2].get_Id()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
        mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

        C = sp.kron(*mats)
    else:
        # the message previously referred to the basis change matrix (copy-paste)
        raise NotImplementedError(f'Dirichlet recombination matrix not implemented for {ndim} dimension!')

    return C
def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Build the operator that changes the basis along the given axes. Some spectral bases change basis while
    differentiating, and this lets you transform to whatever basis you need. `kwargs` are forwarded to the method
    of the same name in the 1D bases; refer to them for admissible parameters. In two dimensions, each 1D matrix is
    paired with an identity in the other direction via a Kronecker product, and the factors for several axes are
    multiplied together.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes, last one first.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(range(-1, -self.ndim - 1, -1))

    ndim = len(self.axes)
    if ndim == 1:
        return self.axes[0].get_basis_change_matrix(**kwargs)
    if ndim != 2:
        raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')

    sp = self.sparse_lib
    for axis in axes:
        other = (axis + 1) % ndim
        C1D = self.axes[axis].get_basis_change_matrix(**kwargs)
        I1D = sp.eye(self.axes[other].N) if len(axes) > 1 else self.axes[other].get_Id()

        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(I1D, other)
        term = sp.kron(*factors)

        C = term if axis == axes[0] else C @ term
    return C