Coverage for pySDC/helpers/spectral_helper.py: 92%
773 statements
coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain, or None if either boundary was not supplied
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np

    def __init__(self, N, x0=None, x1=None, useGPU=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
        """
        self.N = N
        self.x0 = x0
        self.x1 = x1
        # The boundaries are optional here, so only compute the length when both are known instead of failing with a
        # TypeError on `None - None`.
        self.L = None if x0 is None or x1 is None else x1 - x0
        self.useGPU = useGPU

        if useGPU:
            self.setup_GPU()

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
                        individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter that keeps only the modes with `kmin <= |k| < kmax`.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter, defaults to the largest absolute wavenumber

        Returns:
            sparse matrix (on the CPU, also when running on GPUs)
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        # The filter is assembled with scipy's LIL format on the CPU, so on GPUs both the identity and the mask need
        # to be moved to the host first. Indexing a scipy matrix with a CuPy mask fails.
        if self.useGPU:
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
    """

    def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
                                             use the FFT from the library to compute the DCT (case insensitive)
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented
        """
        assert x0 == -1
        assert x1 == 1
        super().__init__(*args, x0=x0, x1=x1, **kwargs)
        self.transform_type = transform_type

        # `transform` compares the transform type case-insensitively, so do the same here to make sure the FFT
        # utilities exist whenever the FFT-based DCT is selected.
        if self.transform_type.lower() == 'fft':
            self.get_fft_utils()

        self.cache = {}
        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5))

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Conversions that are only implemented in the opposite direction are obtained by inverting that matrix.
        Matrices of the default size are cached once generated, so feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional, defaults to the resolution of this basis)

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = N if N else self.N

        # Only matrices of the default size are cached. Otherwise a request for a different size would be served from,
        # or overwrite, the cached default-size matrix.
        use_cache = N == self.N
        if use_cache and name in self.cache:
            return self.cache[name]

        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)) / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r} or its inverse') from E

            import scipy.sparse as scipy_sparse

            if self.sparse_lib is scipy_sparse:
                mat = self.sparse_lib.linalg.inv(fwd.tocsc())
            else:
                # CuPy has no sparse inverse, so invert on the CPU and move the result back to the GPU
                mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))

        if use_cache:
            self.cache[name] = mat
        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        n = xp.arange(self.N)
        S = (self.sparse_lib.diags(1 / (n[:-1] + 1), offsets=-1) @ self.get_conv('T2U')).tocsc()

        # Fix the integration constant in the first row. `xp.append` instead of `np.append` keeps this working with
        # CuPy arrays.
        S[0, 1::2] = (
            (n / (2 * (n + 1)))[1::2]
            * (-1) ** (xp.arange(self.N // 2))
            / (xp.append([1], xp.arange(self.N // 2 - 1) + 1))
        )
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        xp = self.xp
        cols = xp.arange(self.N)[None, :]
        rows = xp.arange(self.N)[:, None]

        # Entry (k, j) of the first derivative is 2j if k < j and j - k is odd, and zero otherwise. Built vectorised
        # instead of with a Python double loop, which is slow and particularly so for device arrays.
        D = xp.where((rows < cols) & ((cols - rows) % 2 == 1), 2.0 * cols, 0.0)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p))

    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def get_fft_shuffle(self, forward, N):
        """
        In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
        doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
        in the complex plane. This function returns a mask to do the shuffling.

        Args:
            forward (bool): Whether you want the shuffle for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Use as mask
        """
        xp = self.xp
        if forward:
            return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
        else:
            mask = xp.zeros(N, dtype=int)
            mask[: N - N % 2 : 2] = xp.arange(N // 2)
            mask[1::2] = N - xp.arange(N // 2) - 1
            mask[-1] = N // 2
            return mask

    def get_fft_shift(self, forward, N):
        """
        As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.

        Args:
            forward (bool): Whether you want the rotation for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Rotation
        """
        xp = self.xp
        # use the requested size consistently for the mode indices and the normalization
        k = xp.arange(N)
        norm = self.get_norm(N)
        if forward:
            return 2 * xp.exp(-1j * np.pi * k / (2 * N)) * norm
        else:
            shift = xp.exp(1j * np.pi * k / (2 * N))
            shift[0] = 0.5
            return shift / norm

    def get_fft_utils(self):
        """
        Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
        them cached.
        """
        self.fft_utils = {
            'fwd': {},
            'bck': {},
        }

        # forwards transform
        self.fft_utils['fwd']['shuffle'] = self.get_fft_shuffle(True, self.N)
        self.fft_utils['fwd']['shift'] = self.get_fft_shift(True, self.N)

        # backwards transform
        self.fft_utils['bck']['shuffle'] = self.get_fft_shuffle(False, self.N)
        self.fft_utils['bck']['shift'] = self.get_fft_shift(False, self.N)

        return self.fft_utils

    @staticmethod
    def _expand_along_axis(vec, ndim, axis):
        """Index the 1D array `vec` such that it broadcasts along `axis` of an array with `ndim` dimensions."""
        expansion = [np.newaxis] * ndim
        expansion[axis] = slice(None)
        return vec[(*expansion,)]

    def transform(self, u, axis=-1, **kwargs):
        """
        1D DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in spectral space
        """
        transform_type = self.transform_type.lower()
        if transform_type == 'dct':
            # the normalization has to act along the transformed axis, not necessarily along the last one
            return self.fft_lib.dct(u, axis=axis, **kwargs) * self._expand_along_axis(self.norm, u.ndim, axis)
        elif transform_type == 'fft':
            if self.xp.iscomplexobj(u):
                # The FFT-based algorithm below only yields the DCT of real data. The transform is linear, so treat
                # the real and imaginary parts separately.
                return self.transform(u.real, axis=axis, **kwargs) + 1j * self.transform(u.imag, axis=axis, **kwargs)

            result = u.copy()

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['fwd']['shuffle']

            v = u[(*shuffle,)]

            V = self.fft_lib.fft(v, axis=axis, **kwargs)

            V *= self._expand_along_axis(self.fft_utils['fwd']['shift'], u.ndim, axis)

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def itransform(self, u, axis=-1):
        """
        1D inverse DCT along axis.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in physical space
        """
        assert self.norm.shape[0] == u.shape[axis]

        transform_type = self.transform_type.lower()
        if transform_type == 'dct':
            # the normalization has to act along the transformed axis, not necessarily along the last one
            return self.fft_lib.idct(u / self._expand_along_axis(self.norm, u.ndim, axis), axis=axis)
        elif transform_type == 'fft':
            if self.xp.iscomplexobj(u):
                # The FFT-based algorithm below only yields the inverse DCT of real data. Exploit linearity for complex
                # data.
                return self.itransform(u.real, axis=axis) + 1j * self.itransform(u.imag, axis=axis)

            result = u.copy()

            v = self.fft_lib.ifft(u * self._expand_along_axis(self.fft_utils['bck']['shift'], u.ndim, axis), axis=axis)

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['bck']['shuffle']
            V = v[(*shuffle,)]

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        # integral of T_k over [-1, 1] is ((-1)^k + 1) / (1 - k^2) for k >= 2; entry k uses n[k - 1] == k
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_k(0) = 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        # D_p = 2^(p-1) (p-1)! diag(p, p+1, ...) on the p-th super-diagonal, mapping T to the p-th derivative base
        return 2 ** (p - 1) * factorial(p - 1) * sp.diags(xp.arange(N - p) + p, offsets=p)

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base
        """
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        if p_out == p_in:
            # nothing to convert, so skip the inversion of the identity
            return mat_fwd.tocsc()

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        import scipy.sparse as sp

        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = sp.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_basis_change_matrix(
            p_out=1, p_in=0
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        The constant is chosen such that the integral vanishes at x=-1, i.e. it is sum_{k>=1} (-1)^(k-1) u_hat_k
        along `axis`.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp
        n_modes = u_hat.shape[axis]

        # select all modes but the zeroth along `axis` while keeping every entry along the remaining axes
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n_modes)

        # alternating signs for the modes 1, 2, 3, ..., shaped to broadcast along `axis`
        expansion = [np.newaxis] * u_hat.ndim
        expansion[axis] = slice(None)
        signs = ((-1) ** xp.arange(n_modes - 1))[(*expansion,)]

        return xp.sum(u_hat[(*slices,)] * signs, axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems. Transforms are plain FFTs, and differentiation and integration matrices are
    diagonal in spectral space.
    """

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary and not including the right one, which, due to
        periodicity, coincides with the left boundary.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive.
        It follows the FFT convention: zero, then positive, then negative wavenumbers.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()
        # The p-th power of a diagonal matrix is the diagonal raised to the p-th power. This avoids a sparse matrix
        # power, which CuPy does not provide, and works identically on CPUs and GPUs.
        return self.sparse_lib.diags((1j * k) ** p).tocsc()

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L
        # diagonal matrix, so raise the diagonal to the p-th power instead of using a sparse matrix power
        return self.sparse_lib.diags((1 / (1j * k)) ** p).tocsc()

    def transform(self, u, axis=-1, **kwargs):
        """
        1D FFT along axis. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.fft(u, axis=axis, **kwargs)

    def itransform(self, u, axis=-1):
        """
        Inverse 1D FFT.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.ifft(u, axis=axis)

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        # argmin returns the first occurrence of the minimum, without iterating the array in Python
        return self.xp.argmin(self.get_wavenumbers())

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
808class SpectralHelper:
809 """
810 This class has three functions:
811 - Easily assemble matrices containing multiple equations
812 - Direct product of 1D bases to solve problems in more dimensions
813 - Distribute the FFTs to facilitate concurrency.
815 Attributes:
816 comm (mpi4py.Intracomm): MPI communicator
817 debug (bool): Perform additional checks at extra computational cost
818 useGPU (bool): Whether to use GPUs
819 axes (list): List of 1D bases
820 components (list): List of strings of the names of components in the equations
821 full_BCs (list): List of Dictionaries containing all information about the boundary conditions
822 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
823 BCs (sparse matrix): Matrix containing only the BCs
824 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
825 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
826 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs
827 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
828 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
829 global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
830 local_slice (slice): Local slice of the solution as in `mpi4py-fft`
831 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
832 init (tuple): This is the same `init` that is used throughout the problem classes
833 init_forward (tuple): This is the equivalent of `init` in spectral space
834 """
836 xp = np
837 fft_lib = scipy.fft
838 sparse_lib = scipy.sparse
839 linalg = scipy.sparse.linalg
840 dtype = mesh
841 fft_backend = 'fftw'
842 fft_comm_backend = 'MPI'
844 @classmethod
845 def setup_GPU(cls):
846 """switch to GPU modules"""
847 import cupy as cp
848 import cupyx.scipy.sparse as sparse_lib
849 import cupyx.scipy.sparse.linalg as linalg
850 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
852 cls.xp = cp
853 cls.sparse_lib = sparse_lib
854 cls.linalg = linalg
856 cls.fft_backend = 'cupy'
857 cls.fft_comm_backend = 'NCCL'
859 cls.dtype = cupy_mesh
861 def __init__(self, comm=None, useGPU=False, debug=False):
862 """
863 Constructor
865 Args:
866 comm (mpi4py.Intracomm): MPI communicator
867 useGPU (bool): Whether to use GPUs
868 debug (bool): Perform additional checks at extra computational cost
869 """
870 self.comm = comm
871 self.debug = debug
872 self.useGPU = useGPU
874 if useGPU:
875 self.setup_GPU()
877 self.axes = []
878 self.components = []
880 self.full_BCs = []
881 self.BC_mat = None
882 self.BCs = None
884 self.fft_cache = {}
886 @property
887 def u_init(self):
888 """
889 Get empty data container in physical space
890 """
891 return self.dtype(self.init)
893 @property
894 def u_init_forward(self):
895 """
896 Get empty data container in spectral space
897 """
898 return self.dtype(self.init_forward)
900 @property
901 def shape(self):
902 """
903 Get shape of individual solution component
904 """
905 return self.init[0][1:]
907 @property
908 def ndim(self):
909 return len(self.axes)
911 @property
912 def ncomponents(self):
913 return len(self.components)
915 @property
916 def V(self):
917 """
918 Get domain volume
919 """
920 return np.prod([me.L for me in self.axes])
922 def add_axis(self, base, *args, **kwargs):
923 """
924 Add an axis to the domain by deciding on suitable 1D base.
925 Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
926 bases for possible arguments.
928 Args:
929 base (str): 1D spectral method
930 """
931 kwargs['useGPU'] = self.useGPU
933 if base.lower() in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
934 kwargs['transform_type'] = kwargs.get('transform_type', 'fft')
935 self.axes.append(ChebychevHelper(*args, **kwargs))
936 elif base.lower() in ['fft', 'fourier', 'ffthelper']:
937 self.axes.append(FFTHelper(*args, **kwargs))
938 elif base.lower() in ['ultraspherical', 'gegenbauer']:
939 self.axes.append(UltrasphericalHelper(*args, **kwargs))
940 else:
941 raise NotImplementedError(f'{base=!r} is not implemented!')
942 self.axes[-1].xp = self.xp
943 self.axes[-1].sparse_lib = self.sparse_lib
945 def add_component(self, name):
946 """
947 Add solution component(s).
949 Args:
950 name (str or list of strings): Name(s) of component(s)
951 """
952 if type(name) in [list, tuple]:
953 for me in name:
954 self.add_component(me)
955 elif type(name) in [str]:
956 if name in self.components:
957 raise Exception(f'{name=!r} is already added to this problem!')
958 self.components.append(name)
959 else:
960 raise NotImplementedError
962 def index(self, name):
963 """
964 Get the index of component `name`.
966 Args:
967 name (str or list of strings): Name(s) of component(s)
969 Returns:
970 int: Index of the component
971 """
972 if type(name) in [str, int]:
973 return self.components.index(name)
974 elif type(name) in [list, tuple]:
975 return (self.index(me) for me in name)
976 else:
977 raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
979 def get_empty_operator_matrix(self):
980 """
981 Return a matrix of operators to be filled with the connections between the solution components.
983 Returns:
984 list containing sparse zeros
985 """
986 S = len(self.components)
987 O = self.get_Id() * 0
988 return [[O for _ in range(S)] for _ in range(S)]
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
        **kwargs: Forwarded to `get_BC` of the 1D base along `axis`

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # The BC row is assembled with scipy on the host, so GPU data is fetched with `.get()` first
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # only the first position in the other direction receives the BC
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        if self.useGPU:
            Id = Id.get()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        return self.sparse_lib.csc_matrix(sp.kron(*mats))
    else:
        # previously this fell through and returned None, which only failed later and obscurely
        raise NotImplementedError(f'Boundary conditions not implemented for {ndim} dimensions!')
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.
    The entries of the rhs mask that are cleared mirror exactly those that `add_BC` sets.

    Args:
        component (str): Name of the component the BC acts on
        equation (str): Name of the equation for the component you want to remove the BC from
        axis (int): Axis the BC was added to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line the BC was put in
        scalar (bool): Whether the BC was put in all space positions in the other direction
        **kwargs: Forwarded to `get_BC`
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # `add_BC` registers scalar BCs in the mask on rank 0 only, so only clear them there
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = False
    else:
        N = self.axes[axis].N
        # normalize negative lines to a global index before converting to a rank-local one
        global_line = (N + line) % N
        if global_line in self.xp.arange(N)[self.local_slice[axis]]:
            slices = (
                [self.index(equation)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [global_line - self.local_slice[axis].start]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            self.BC_rhs_mask[(*slices,)] = False
def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
        **kwargs: Forwarded to `get_BC` and stored with the BC
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # scalar BCs are only registered in the rhs mask on rank 0
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        N = self.axes[axis].N
        # Normalize negative lines to a global index first: subtracting the local start from a negative index
        # yields a wrong (or out of range) index on ranks whose local slice does not start at 0.
        global_line = (N + line) % N
        if global_line in self.xp.arange(N)[self.local_slice[axis]]:
            slices = (
                [self.index(equation)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [global_line - self.local_slice[axis].start]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            self.BC_rhs_mask[(*slices,)] = True
def setup_BCs(self):
    """
    Assemble the boundary condition operator from the nested list of BC matrices and prepare everything needed for
    boundary bordering:
    a diagonal matrix that zeros out all rows carrying a BC, and the BC values in spectral space, ready to be added
    to the right hand side.
    """
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    # flat indices of all rows that carry a boundary condition
    n_dofs = np.prod(self.init[0])
    self.BC_zero_index = self.xp.arange(n_dofs)[self.BC_rhs_mask.flatten()]

    # identity everywhere except on BC rows, which are zeroed
    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = self.sparse_lib.diags(keep_rows)

    # transform the BC values once, so they can later be added to the rhs in spectral space directly
    self.rhs_BCs_hat = self.transform(self.put_BCs_in_rhs(self.u_init))
def check_BCs(self, u):
    """
    Verify that `u` satisfies all non-scalar boundary conditions that were registered via `add_BC`.
    Only one- and two-dimensional problems are supported.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a boundary condition is violated or there are more than two dimensions
    """
    assert self.ndim < 3
    not_forwarded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        relevant = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not relevant:
            continue

        # transform only along the axis the BCs act on
        u_hat = self.transform(u, axes=(axis - self.ndim,))

        for BC in relevant:
            bc_kwargs = {key: BC[key] for key in BC if key not in not_forwarded}
            bc_row = self.axes[axis].get_BC(**bc_kwargs)
            component_hat = u_hat[self.index(BC['component'])]
            # apply the BC row from the side matching the axis (only axes 0 and 1 are possible here)
            get = bc_row @ component_hat if axis == 0 else component_hat @ bc_row
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'
def put_BCs_in_matrix(self, A):
    """
    Apply boundary bordering to `A`: rows that carry a BC are zeroed and the BC rows are added in their place.

    Args:
        A: Matrix of the size of the full discretization

    Returns:
        Matrix with the BCs in place
    """
    A_without_BC_rows = self.BC_line_zero_matrix @ A
    return A_without_BC_rows + self.BCs
def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.
    Note that the entries of `rhs_hat` at BC positions are zeroed in place.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Mask of the rhs entries that are replaced by the boundary conditions; computed once and cached
        self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if axis != bc['axis']:
                    continue

                N = self.axes[axis].N
                # Normalize negative lines to a global index before converting to a rank-local one. Subtracting
                # the local start from a negative index is wrong on ranks whose local slice does not start at 0.
                global_line = (N + bc['line']) % N
                if global_line not in self.xp.arange(N)[self.local_slice[axis]]:
                    continue

                _slice = (
                    [self.index(bc['equation'])]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                    + [global_line - self.local_slice[axis].start]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                )
                self._rhs_hat_zero_mask[(*_slice,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat
def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs

    Raises:
        AssertionError: If `rhs` is flattened
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        for bc in self.full_BCs:
            if axis != bc['axis']:
                continue

            N = self.axes[axis].N
            # Normalize negative lines to a global index before converting to a rank-local one. Subtracting the
            # local start from a negative index is wrong on ranks whose local slice does not start at 0.
            global_line = (N + bc['line']) % N
            if global_line in self.xp.arange(N)[self.local_slice[axis]]:
                _slice = (
                    [self.index(bc['equation'])]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                    + [global_line - self.local_slice[axis].start]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                )
                _rhs_hat[(*_slice,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs
def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled
        equation (str): The equation of the component you want this in
        relations: (dict): Relations between quantities, mapping component names to the operators acting on them
    """
    for component in relations:
        row, col = self.index(equation), self.index(component)
        A[row][col] = relations[component]
def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator to be converted

    Returns:
        sparse linear operator; with a single component the only entry of `M` is returned as is
    """
    if len(self.components) != 1:
        return self.sparse_lib.bmat(M, format='csc')
    return M[0][0]
def get_wavenumbers(self):
    """
    Get grid in spectral space, restricted to the modes in `local_slice`.

    Returns:
        list of arrays as returned by `meshgrid`, built from the axes in reverse order
    """
    local_wavenumbers = []
    for i, base in enumerate(self.axes):
        local_wavenumbers.append(base.get_wavenumbers()[self.local_slice[i]])
    return self.xp.meshgrid(*reversed(local_wavenumbers))
def get_grid(self):
    """
    Get grid in physical space, restricted to the points in `local_slice`.

    Returns:
        list of arrays as returned by `meshgrid`, built from the axes in reverse order
    """
    local_points = []
    for i, base in enumerate(self.axes):
        local_points.append(base.get_1dgrid()[self.local_slice[i]])
    return self.xp.meshgrid(*reversed(local_points))
def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get a (cached) transform. Without MPI, the plain n-dimensional FFT functions of `xp.fft` are used; with MPI, we
    use `PFFT` objects generated by mpi4py-fft, or their forward / backward methods.

    Args:
        axes (tuple): Axes you want to transform over
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))
    if shape is None:
        shape = self.global_shape[1:]
    if padding is None:
        padding = [1] * self.ndim

    key = (axes, direction, tuple(padding), tuple(shape))
    if key in self.fft_cache:
        return self.fft_cache[key]

    if self.comm is None:
        assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'
        if direction == 'object':
            # there is no transform object in serial
            self.fft_cache[key] = None
        elif direction in ('forward', 'backward'):
            self.fft_cache[key] = self.xp.fft.fftn if direction == 'forward' else self.xp.fft.ifftn
    elif direction == 'object':
        from mpi4py_fft import PFFT

        self.fft_cache[key] = PFFT(
            comm=self.comm,
            shape=shape,
            axes=sorted(axes),
            dtype='D',
            collapse=False,
            backend=self.fft_backend,
            comm_backend=self.fft_comm_backend,
            padding=padding,
        )
    else:
        # build (or fetch) the PFFT object for this configuration and cache its bound transform
        pfft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)
        if direction == 'forward':
            self.fft_cache[key] = pfft.forward
        elif direction == 'backward':
            self.fft_cache[key] = pfft.backward

    # an unknown `direction` leaves the cache untouched and raises a KeyError here
    return self.fft_cache[key]
def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.
    If no component has been added yet, a single component 'u' is added.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
    self.local_slice = [slice(0, me.N) for me in self.axes]

    axes = tuple(range(len(self.axes)))
    self.fft_obj = self.get_fft(axes=axes, direction='object')
    if self.fft_obj is not None:
        self.local_slice = self.fft_obj.local_slice(False)

    # Determine the local shape by indexing a zero-strided view of the global shape. This has the same indexing
    # semantics as a real array, but does not allocate the full global array on every rank.
    local_shape = np.broadcast_to(0.0, self.global_shape)[(..., *self.local_slice)].shape

    self.init = (
        local_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.xp.zeros(
        shape=self.init[0],
        dtype=bool,
    )
1427 def _transform_fft(self, u, axes, **kwargs):
1428 """
1429 FFT along `axes`
1431 Args:
1432 u: The solution
1433 axes (tuple): Axes you want to transform over
1435 Returns:
1436 transformed solution
1437 """
1438 # TODO: clean up and try putting more of this in the 1D bases
1439 fft = self.get_fft(axes, 'forward', **kwargs)
1440 return fft(u, axes=axes)
def _transform_dct(self, u, axes, padding=None, **kwargs):
    '''
    DCT along `axes`, computed via an FFT of reordered data followed by a multiplication with a shift.
    This will only return real values!
    When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
    wavenumbers in FFTs.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over
        padding (list): Padding factors per axis. If not None, modes beyond 1 / padding[axis] of the length along
            `axis` are discarded after transforming
        **kwargs: Forwarded to `get_fft` when `padding` is None

    Returns:
        transformed solution (real part only)
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    if len(axes) > 1:
        # transform one axis at a time, last axes first
        # NOTE(review): `padding` is not forwarded to the recursive calls, so it is ignored when several axes are
        # transformed here at once -- confirm this is intended
        v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
    else:
        v = u.copy().astype(complex)
        axis = axes[0]
        base = self.axes[axis]

        # reorder the data along `axis` with the base's shuffle before the FFT
        shuffle = [slice(0, s, 1) for s in u.shape]
        shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
        v = v[(*shuffle,)]

        if padding is not None:
            shape = list(v.shape)
            if self.comm:
                # sum the local extents of axis 0 over all ranks to get its global extent
                # (assumes the data is distributed along axis 0 -- TODO confirm)
                send_buf = np.array(v.shape[0])
                recv_buf = np.array(v.shape[0])
                self.comm.Allreduce(send_buf, recv_buf)
                shape[0] = int(recv_buf)
            fft = self.get_fft(axes, 'forward', shape=shape)
        else:
            fft = self.get_fft(axes, 'forward', **kwargs)

        v = fft(v, axes=axes)

        # index that broadcasts the 1D shift along `axis` over all other dimensions
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        if padding is not None:
            shift = base.get_fft_shift(True, v.shape[axis])

            if padding[axis] != 1:
                # keep only the first ceil(n / padding) modes along `axis`
                N = int(np.ceil(v.shape[axis] / padding[axis]))
                _expansion = [slice(0, n) for n in v.shape]
                _expansion[axis] = slice(0, N, 1)
                v = v[(*_expansion,)]
        else:
            shift = base.fft_utils['fwd']['shift']

        # NOTE(review): when padding[axis] != 1, `v` was truncated along `axis` above, while `expansion` was built
        # from the pre-truncation length -- confirm that the shapes are compatible in that case
        v *= shift[(*expansion,)]

    return v.real
def transform_single_component(self, u, axes=None, padding=None):
    """
    Transform a single component of the solution.
    Axes discretized with the same type of 1D base are grouped and transformed together. When using MPI, the data is
    realigned between the groups via `get_aligned`.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # transform to use for each type of 1D base; matched by exact type below, so subclasses are not picked up
    trfs = {
        ChebychevHelper: self._transform_dct,
        UltrasphericalHelper: self._transform_dct,
        FFTHelper: self._transform_fft,
    }

    # default: all axes, given as negative indices
    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )  # You know, sometimes I feel very strongly about Black still. This atrocious formatting is readable by Sauron only.

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by the type of their base, dropping base types with no requested axis
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]

    fft = self.get_fft(axes=axes, padding=padding, direction='object')
    if fft is not None:
        shape = list(fft.global_shape(False))

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        # NOTE(review): `1 + self.ndim + _ax` only indexes the spatial part of `global_shape` for negative `_ax`,
        # i.e. this assumes axes are given as negative indices -- confirm for callers passing positive axes
        for _ax in _axes:
            shape[_ax] = self.global_shape[1 + self.ndim + _ax]

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        # redistribute the data so that it can be transformed along this group of axes
        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
        )

        alignment = self.ndim + _axes[-1]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        if self.comm is not None:
            # presumably compensates for the normalization of the distributed forward transform -- TODO confirm
            _out *= np.prod([self.axes[i].N for i in _axes])

        # realign for the next group; after the last group the sentinel (-1,) realigns to the last axis
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
        )

    return result
def transform(self, u, axes=None, padding=None):
    """
    Transform all components from physical space to spectral space

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform; defaults to all axes, given as negative indices
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data, with the components stacked along the first axis
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.transform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)
1606 def _transform_ifft(self, u, axes, **kwargs):
1607 # TODO: clean up and try putting more of this in the 1D bases
1608 ifft = self.get_fft(axes, 'backward', **kwargs)
1609 return ifft(u, axes=axes)
def _transform_idct(self, u, axes, padding=None, **kwargs):
    '''
    Inverse DCT along `axes`, computed by multiplying with a shift, an inverse FFT and reordering the result.
    This will only ever return real values!

    Args:
        u: Data in spectral space
        axes (tuple): Axes you want to transform over
        padding (list): Padding factors per axis. If padding[axis] != 1, zeros are appended along `axis` before
            transforming
        **kwargs: Forwarded to `get_fft` in the non-padded cases

    Returns:
        transformed data (real part only)
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    v = u.copy().astype(complex)

    if len(axes) > 1:
        # transform one axis at a time, last axes first
        # NOTE(review): neither `padding` nor `**kwargs` (e.g. `shape`) are forwarded to the recursive calls --
        # confirm this is intended
        v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
    else:
        axis = axes[0]
        base = self.axes[axis]

        if padding is not None:
            if padding[axis] != 1:
                # append zeros along `axis` up to the padded resolution
                # NOTE(review): N_pad is computed from v.shape[axis] but the pad width from base.N; these only
                # agree when v.shape[axis] == base.N -- confirm
                N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
                _pad = [[0, 0] for _ in v.shape]
                _pad[axis] = [0, N_pad - base.N]
                v = self.xp.pad(v, _pad, 'constant')

                shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
            else:
                shift = base.fft_utils['bck']['shift']
        else:
            shift = base.fft_utils['bck']['shift']

        # index that broadcasts the 1D shift along `axis` over all other dimensions
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        v *= shift[(*expansion,)]

        if padding is not None:
            if padding[axis] != 1:
                shape = list(v.shape)
                if self.comm:
                    # sum the local extents of axis 0 over all ranks to get its global extent
                    # (assumes the data is distributed along axis 0 -- TODO confirm)
                    send_buf = np.array(v.shape[0])
                    recv_buf = np.array(v.shape[0])
                    self.comm.Allreduce(send_buf, recv_buf)
                    shape[0] = int(recv_buf)
                ifft = self.get_fft(axes, 'backward', shape=shape)
            else:
                ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        else:
            ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        v = ifft(v, axes=axes)

        # undo the reordering along `axis` with the base's shuffle
        shuffle = [slice(0, s, 1) for s in v.shape]
        shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
        v = v[(*shuffle,)]

    return v.real
def itransform_single_component(self, u, axes=None, padding=None):
    """
    Inverse transform over single component of the solution.
    Axes discretized with the same type of 1D base are grouped and transformed together. When using MPI, the data is
    realigned between the groups via `get_aligned`.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # transform to use for each type of 1D base; matched by exact type below, so subclasses are not picked up
    trfs = {
        FFTHelper: self._transform_ifft,
        ChebychevHelper: self._transform_idct,
        UltrasphericalHelper: self._transform_idct,
    }

    # default: all axes, given as negative indices
    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by the type of their base, dropping base types with no requested axis
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = list(self.global_shape[1:])

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        # redistribute the data so that it can be transformed along this group of axes
        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
        )
        if self.comm is not None:
            # presumably undoes the scaling applied in `transform_single_component` -- TODO confirm
            _in /= np.prod([self.axes[i].N for i in _axes])

        alignment = self.ndim + _axes[0]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # update the shape along the transformed axes (it changes when padding); `_input_shape` is a private
        # attribute of the mpi4py-fft object
        for _ax in _axes:
            if fft:
                shape[_ax] = fft._input_shape[_ax]
            else:
                shape[_ax] = _out.shape[_ax]

        # realign for the next group; after the last group the sentinel (-1,) realigns to the last axis
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
        )

    return result
def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
    """
    Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
    object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
    shape especially, if applicable.
    Without MPI, on a single rank, or if the alignment does not change, a copy of `u` is returned.

    Args:
        u: The solution
        axis_in (int): Current alignment
        axis_out (int): New alignment
        fft (mpi4py_fft.PFFT), optional: parallel FFT object, only used in the fallback redistribution
        forward (bool): Whether the input is in spectral space or not
        **kwargs: Forwarded to `get_fft` to obtain the PFFT object whose transfer objects are used

    Returns:
        solution aligned on `axis_out`
    """
    if self.comm is None or axis_in == axis_out:
        return u.copy()
    if self.comm.size == 1:
        return u.copy()

    fft = self.get_fft(**kwargs) if fft is None else fft

    # axes between which the transfer objects of the PFFT object can move the data
    global_fft = self.get_fft(**kwargs)
    axisA = [me.axisA for me in global_fft.transfer]
    axisB = [me.axisB for me in global_fft.transfer]

    current_axis = axis_in

    if axis_in in axisA and axis_out in axisB:
        # step through the transfer objects in forward direction until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisA.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayA[:] = u[:]

            transfer.forward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisB
            u = arrayB

        return u
    elif axis_in in axisB and axis_out in axisA:
        # step through the transfer objects in backward direction until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisB.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayB[:] = u[:]

            transfer.backward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisA
            u = arrayA

        return u
    else:  # go the potentially slower route of not reusing transfer classes
        from mpi4py_fft import newDistArray

        _in = newDistArray(fft, forward).redistribute(axis_in)
        _in[...] = u

        return _in.redistribute(axis_out)
def itransform(self, u, axes=None, padding=None):
    """
    Inverse transform over all components of the solution

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform; defaults to all axes, given as negative indices
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before transforming

    Returns:
        Transformed data, with the components stacked along the first axis
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.itransform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)
def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
    which you can sort out via the `SpectralHelper.local_slice` attribute. When constructing a 1D matrix, you can
    use this method to get the part corresponding to the modes carried by this rank.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions. This means slab decomposition only.

    Returns:
        sparse local matrix in CSC format
    """
    local_modes = self.local_slice[axis]
    M_csc = M.tocsc()
    return M_csc[local_modes, local_modes]
def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible.

    Args:
        axis (int): Axis along which to filter
        **kwargs: Forwarded to `get_filter_matrix` of the 1D base along `axis`

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    # Restrict every 1D factor to the modes carried by this rank, like the other operator constructors do, such that
    # the Kronecker product matches the size of the local data when using distributed FFTs.
    mats = [self.get_local_slice_of_1D_matrix(base.get_Id(), axis=i) for i, base in enumerate(self.axes)]
    mats[axis] = self.get_local_slice_of_1D_matrix(self.axes[axis].get_filter_matrix(**kwargs), axis=axis)
    return self.sparse_lib.kron(*mats)
def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = self.ndim

    if ndim == 1:
        return self.axes[0].get_differentiation_matrix(**kwargs)
    if ndim != 2:
        raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!')

    for axis in axes:
        other = (axis + 1) % ndim
        D1D = self.axes[axis].get_differentiation_matrix(**kwargs)
        # with several axes a plain identity is used in the other direction, otherwise the base's own identity
        I1D = sp.eye(self.axes[other].N) if len(axes) > 1 else self.axes[other].get_Id()

        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(I1D, other)
        term = sp.kron(*factors)

        # the product restarts whenever the first listed axis comes up
        D = term if axis == axes[0] else D @ term

    return D
def get_integration_matrix(self, axes):
    """
    Get integration matrix to integrate along specified axis.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        return self.axes[0].get_integration_matrix()
    if ndim != 2:
        raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')

    for axis in axes:
        other = (axis + 1) % ndim
        S1D = self.axes[axis].get_integration_matrix()
        # with several axes a plain identity is used in the other direction, otherwise the base's own identity
        I1D = sp.eye(self.axes[other].N) if len(axes) > 1 else self.axes[other].get_Id()

        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(I1D, other)
        term = sp.kron(*factors)

        # the product restarts whenever the first listed axis comes up
        S = term if axis == axes[0] else S @ term

    return S
def get_Id(self):
    """
    Get identity matrix of the size of the local data.

    Returns:
        sparse identity matrix (complex in two dimensions)

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = self.ndim
    # complex identity of the local spatial size; it is replaced by the base's identity in 1D
    identity = sp.eye(np.prod(self.init[0][1:]), dtype=complex)

    if ndim == 1:
        return self.axes[0].get_Id()
    if ndim != 2:
        raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!')

    for axis in range(ndim):
        other = (axis + 1) % ndim
        own_Id = self.axes[axis].get_Id()
        other_Id = sp.eye(self.axes[other].N)

        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(own_Id, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(other_Id, other)
        identity = identity @ sp.kron(*factors)

    return identity
def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get Dirichlet recombination matrix along axis. Note that it only makes sense in directions discretized with
    variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        C = self.axes[0].get_Dirichlet_recombination_matrix()
    elif ndim == 2:
        axis2 = (axis + 1) % ndim
        C1D = self.axes[axis].get_Dirichlet_recombination_matrix()

        # identity of the base in the other direction
        I1D = self.axes[axis2].get_Id()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
        mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

        C = sp.kron(*mats)
    else:
        # the message used to be copy-pasted from `get_basis_change_matrix`
        raise NotImplementedError(f'Dirichlet recombination matrix not implemented for {ndim} dimension!')

    return C
def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases do a change between bases while differentiating. This method returns matrices that changes the basis to whatever you want.
    Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.

    Args:
        axes (tuple): Axes along which to change basis; defaults to all axes as negative indices, last axis first.

    Returns:
        sparse basis change matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))

    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        return self.axes[0].get_basis_change_matrix(**kwargs)
    if ndim != 2:
        raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')

    for axis in axes:
        other = (axis + 1) % ndim
        C1D = self.axes[axis].get_basis_change_matrix(**kwargs)
        # with several axes a plain identity is used in the other direction, otherwise the base's own identity
        I1D = sp.eye(self.axes[other].N) if len(axes) > 1 else self.axes[other].get_Id()

        factors = [None] * ndim
        factors[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(I1D, other)
        term = sp.kron(*factors)

        # the product restarts whenever the first listed axis comes up
        C = term if axis == axes[0] else C @ term

    return C