Coverage for pySDC/helpers/spectral_helper.py: 83%
865 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-12 05:46 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-12 05:46 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
5from functools import partial, wraps
6import logging
def cache(func):
    """
    Decorator for caching return values of methods.
    This is very similar to `functools.cache`, but without the memory leaks (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/): instead of one global cache that keeps every `self`
    alive, each instance stores its own cache in an attribute named ``_<method name>_cache``, so the cached values are
    released together with the instance.

    Example:

    .. code-block:: python

        class Counter:
            num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        counter = Counter()
        counter.increment(0)  # returns 1, num_calls = 1
        counter.increment(1)  # returns 2, num_calls = 2
        counter.increment(0)  # returns 1, num_calls = 2

    The cached object itself is returned on every hit, so mutating a returned value changes the cached value as well.
    Calls whose arguments are not hashable (e.g. numpy arrays) cannot be looked up in the cache and are forwarded to
    `func` every time without caching.

    Args:
        func (function): The method you want to cache the return value of. Its first argument must be `self`.

    Returns:
        function: Wrapped method returning the (possibly cached) return value of `func`
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, attr_cache):
            setattr(self, attr_cache, {})

        cache = getattr(self, attr_cache)

        try:
            key = (args, frozenset(kwargs.items()))
            hash(key)
        except TypeError:
            # unhashable arguments cannot serve as a cache key, so compute the result without caching it
            return func(self, *args, **kwargs)

        if key in cache:
            return cache[key]
        result = func(self, *args, **kwargs)
        cache[key] = result
        return result

    return wrapper
class vkFFT(object):
    """
    pyVkFFT FFT backend.
    The special feature of vkFFT is fast DCT on GPU with cached plans.

    The static transform methods mimic the signatures of the corresponding `scipy.fft` functions so that this class can
    be used as an FFT library. Only the options supported here are accepted; anything else is rejected with an
    assertion. The inputs are never modified: every transform works on a copy. The `s` argument is accepted for
    interface compatibility only and is ignored, i.e. no padding or truncation takes place.
    """

    # Plans are shared between all calls and keyed by the full transform configuration, see `get_plan`.
    cached_plans = {}

    @staticmethod
    def is_complex(x):
        """Return whether the dtype of array `x` is complex."""
        return 'complex' in str(x.dtype)

    @staticmethod
    def _as_axes(x, axes):
        """Return `axes` as a tuple, defaulting to all axes of `x` as `scipy.fft` does for ``axes=None``."""
        return tuple(range(x.ndim)) if axes is None else tuple(axes)

    @staticmethod
    def get_plan(transform_type, shape, dtype, axes, norm):
        """
        Get a VkFFT plan, creating and caching it on first use.

        Args:
            transform_type (str): 'fft' for a complex FFT or 'dct' for a type 2 DCT
            shape (tuple): Shape of the arrays the plan is applied to
            dtype: Dtype of the arrays the plan is applied to
            axes (tuple): Axes to transform along
            norm (str): Normalization, only 'backward' is supported

        Returns:
            pyvkfft.cuda.VkFFTApp: The plan
        """
        from pyvkfft.cuda import VkFFTApp

        assert norm == 'backward'

        key = f'{transform_type=}, {shape=}, {dtype=}, {axes=}, {norm=}'

        if key not in vkFFT.cached_plans:
            kwargs = {}
            if transform_type == 'dct':
                kwargs['dct'] = 2

            vkFFT.cached_plans[key] = VkFFTApp(shape, dtype, len(axes), axes=axes, norm=norm, **kwargs)

            logger = logging.getLogger(name='VkFFT')
            logger.debug(f'Cached plan for VkFFT: {key}')
        return vkFFT.cached_plans[key]

    @staticmethod
    def fftn(x, s=None, axes=None, norm='backward', overwrite_x=False):
        """
        Forward complex FFT along `axes`, mirroring `scipy.fft.fftn` with ``norm='backward'``.

        Returns:
            Complex transformed copy of `x`
        """
        assert not overwrite_x  # for consistent interface with scipy
        assert norm == 'backward'  # for consistent interface with scipy

        _x = x + 0j  # complex copy of the input, which the plan transforms in place
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=_x.shape,
            dtype=_x.dtype,  # the plan must match the complex array it is applied to, not the possibly real input
            axes=vkFFT._as_axes(x, axes),
            norm=norm,
        )
        plan.fft(_x)
        return _x

    @staticmethod
    def ifftn(x, s=None, axes=None, norm='forward', overwrite_x=False):
        """
        Inverse complex FFT along `axes`, mirroring `scipy.fft.ifftn` with ``norm='forward'``, i.e. the inverse is not
        normalized.

        Returns:
            Complex transformed copy of `x`
        """
        assert norm == 'forward'
        assert not overwrite_x  # for consistent interface with scipy

        axes = vkFFT._as_axes(x, axes)
        _x = x + 0j  # complex copy of the input, which the plan transforms in place
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=_x.shape,
            dtype=_x.dtype,
            axes=axes,
            norm='backward',
        )
        plan.ifft(_x)

        # The plan's 'backward' normalization scales the inverse by 1/N, where N is the total number of points in the
        # transformed axes, i.e. the *product* of their lengths. Undo it to match scipy's norm='forward'.
        n_points = 1
        for axis in axes:
            n_points *= x.shape[axis]
        return _x * n_points

    @staticmethod
    def _dct(x, axes, norm, inverse):
        """
        Apply the cached type 2 DCT plan (or its inverse) to a copy of `x`.
        The plan is built for the real dtype, so complex input is transformed part by part.

        Args:
            x: Data to transform
            axes (tuple or None): Axes to transform along, None for all
            norm (str or None): Normalization, None is treated as 'backward' like in scipy
            inverse (bool): Apply the inverse transform instead of the forward one

        Returns:
            Transformed copy of `x`
        """
        norm = 'backward' if norm is None else norm
        is_complex = vkFFT.is_complex(x)

        plan = vkFFT.get_plan(
            transform_type='dct',
            shape=x.shape,
            dtype=x.real.dtype if is_complex else x.dtype,
            axes=vkFFT._as_axes(x, axes),
            norm=norm,
        )
        apply = plan.ifft if inverse else plan.fft

        if is_complex:
            x_real = x.real.copy()
            x_imag = x.imag.copy()
            apply(x_real)
            apply(x_imag)
            return x_real + 1j * x_imag

        # transform a copy: the plan works in place and the input must not be overwritten
        _x = x.copy()
        apply(_x)
        return _x

    @staticmethod
    def dctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """Type 2 DCT along `axes`, mirroring `scipy.fft.dctn`. Returns a transformed copy of `x`."""
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy
        return vkFFT._dct(x, axes=axes, norm=norm, inverse=False)

    @staticmethod
    def idctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """Inverse of the type 2 DCT along `axes`, mirroring `scipy.fft.idctn`. Returns a transformed copy of `x`."""
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy
        return vkFFT._dct(x, axes=axes, norm=norm, inverse=True)
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
    """

    # Module switches. `setup_CPU` and `setup_GPU` overwrite these on the class, so all instances of a class share the
    # backend that was chosen most recently.
    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both `useGPU` and `useFFTW` are requested
        """
        # Validate before `setup_GPU`/`setup_CPU` switch the modules on the class. Otherwise a rejected configuration
        # would still leave the class, and thereby all its other instances, switched over.
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
            self.logger.debug('Set up for GPU')
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

        # set every attribute that `setup_CPU` sets as well, so that no CPU settings are left over after switching
        cls.fft_backend = 'cupyx-scipy'
        cls.fft_comm_backend = 'NCCL'
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """switch to CPU modules"""

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError()

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
            individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter that keeps only the modes with kmin <= |k| < kmax.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter, defaults to the largest |k|, which is then removed as well

        Returns:
            sparse matrix in CSC format, on the GPU when running on GPUs
        """
        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        Id = self.get_Id()
        if self.useGPU:
            # column assignment needs scipy's LIL format, so the masking is done on the CPU with a CPU mask
            Id = Id.get()
            mask = mask.get()

        F = Id.tolil()
        F[:, mask] = 0

        if self.useGPU:
            return self.sparse_lib.csc_matrix(F.tocsc())
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError()
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        The grid, differentiation and integration matrices are mapped affinely from the reference interval [-1, 1] to
        [x0, x1]. The boundary condition rows, however, take positions `x` in reference coordinates.

        Args:
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like, but do not modify
        the returned matrices in place, as that would modify the cached ones.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its reverse is implemented
        '''
        N = N if N else self.N
        sparse = self.sparse_lib

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sparse.eye(N) - sparse.eye(N, k=2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sparse.eye(N) - sparse.eye(N, k=2)
            elif name[0] == name[-1]:
                mat = sparse.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            return get_forward_conv(name)
        except NotImplementedError as E:
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError from E

        # Invert the reverse conversion. CuPy has no sparse inverse, so on GPU we invert on the CPU and copy back.
        if self.sparse_lib is scipy.sparse:
            return scipy.sparse.linalg.inv(fwd.tocsc())
        return self.sparse_lib.csc_matrix(scipy.sparse.linalg.inv(fwd.tocsc().get()))

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: For any other lower bound than 0
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        S = self.sparse_lib.diags(1 / (xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = xp.arange(self.N)

        # divisors 1, 1, 2, 3, ...; built with `xp` rather than `np.append` so that it also works with GPU arrays
        divisors = xp.concatenate((xp.ones(1), xp.arange(self.N // 2 - 1) + 1))

        S = S.tocsc()
        S[0, 1::2] = (
            (n / (2 * (xp.arange(self.N) + 1)))[1::2] * (-1) ** (xp.arange(self.N // 2)) / divisors
        ) * self.lin_trf_fac
        return S

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        n = self.xp.arange(self.N, dtype=float)

        # integrals of T_n over [-1, 1]: ((-1)^n + 1) / (1 - n^2) for n >= 2, 2 for n = 0 and 0 for n = 1
        weights = (-1) ** n + 1
        weights[2:] /= 1 - (n**2)[2:]

        # scale from the reference interval of length 2 to the actual domain length
        weights /= 2 / self.L
        return weights

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        # first derivative: entry (k, j) is 2 j if k < j and j - k is odd, zero otherwise; the first row is halved
        j = self.xp.arange(self.N)
        k = j[:, None]
        D = self.xp.where((k < j) & ((j - k) % 2 == 1), 2.0 * j, 0.0)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes, undoing the normalization applied in `transform`. `kwargs` will be passed on to the FFT
        library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # Normalize a copy so that `u` is left untouched. The normalizations of all axes must accumulate in this one
        # array; restarting from `u` for every axis would keep only the normalization of the last axis.
        _u = u.copy()
        for axis in axes:
            if self.N != _u.shape[axis]:
                _u = self._reorder_padded_coefficients(_u, axis)

            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N
            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def _reorder_padded_coefficients(self, u, axis):
        """
        Rearrange coefficients that are padded along `axis` before the inverse DCT.
        mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are added in
        the middle rather than the end. We need to resort this here and put the padding in the end.

        Args:
            u: Padded data in spectral space
            axis (int): Axis along which `u` is padded

        Returns:
            Rearranged copy of `u`
        """
        N = self.N
        _u = self.xp.zeros_like(u)

        # copy first half
        su = [slice(None)] * u.ndim
        su[axis] = slice(0, N // 2 + 1)
        _u[tuple(su)] = u[tuple(su)]

        # copy second half
        su = [slice(None)] * u.ndim
        su[axis] = slice(-(N // 2), None)
        s_u = [slice(None)] * u.ndim
        s_u[axis] = slice(N // 2, N // 2 + (N // 2))
        _u[tuple(s_u)] = u[tuple(su)]

        if N % 2 == 0:
            su = [slice(None)] * u.ndim
            su[axis] = N // 2
            _u[tuple(su)] *= 2
        return _u

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` and `'neumann'` boundary conditions, which is the
        coordinate at which to set the BC.

        Args:
            kind ('integral', 'dirichlet' or 'neumann'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        elif kind.lower() == 'neumann':
            return self.get_Neumann_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the reference interval [-1, 1].

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition in reference coordinates, one of -1, 0 and 1

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_n(0) = cos(n pi / 2) = 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Neumann_BC_row(self, x):
        """
        Get a row for generating Neumann BCs at x with T polynomials.
        It returns the derivatives T_n'(x) = n^2 at x = 1 and (-1)^(n+1) n^2 at x = -1.

        Args:
            x (float): Position of the boundary condition in reference coordinates, one of -1 and 1

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N, dtype='D')
        nn = n**2
        if x == -1:
            me = nn
            me[1:] *= (-1) ** n[:-1]
            return me
        elif x == 1:
            return nn
        else:
            raise NotImplementedError(f'Don\'t know how to generate Neumann BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib

        return sp.eye(N) - sp.eye(N, k=2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped
    from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.
        The matrix maps Chebychev T coefficients to coefficients in derivative base `p`.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        return 2 ** (p - 1) * factorial(p - 1) * sp.diags(xp.arange(N - p) + p, offsets=p) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # built with scipy in any case, as the column modification needs the LIL format
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        if p_out == p_in:
            # nothing to convert: skip inverting the identity, but return the CSC format the inversion would give
            return mat_fwd.tocsc()

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        import scipy.sparse as sp

        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = sp.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix`
        for details. The constant is chosen such that the integral vanishes at x=-1, i.e. it is
        sum_{n >= 1} (-1)^(n+1) u_hat_n along `axis`.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        n = u_hat.shape[axis]

        # select all coefficients but the zeroth along `axis` and keep every other axis whole
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # align the alternating signs with `axis`, so that they broadcast correctly for any axis and dimension
        expansion = [np.newaxis] * u_hat.ndim
        expansion[axis] = slice(None)
        signs = (-1) ** (self.xp.arange(n - 1))

        return self.xp.sum(u_hat[tuple(slices)] * signs[tuple(expansion)], axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems on the interval [x0, x1). The forward transform is unnormalized and the
    inverse transform divides by the number of points, so that `itransform(transform(u))` recovers `u`.
    """

    distributable = True

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary `x0` and excluding the right boundary `x1`, which is
        identified with the left boundary due to periodicity.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive: it is the `fftfreq` ordering, i.e. zero and the positive
        wavenumbers first, followed by the negative ones.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()
        D = self.sparse_lib.diags(1j * k)

        if not self.useGPU:
            return self.linalg.matrix_power(D, p)

        if p == 1:
            return D

        # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time
        # of writing. This also handles p = 0 correctly, which yields the identity.
        from scipy.sparse.linalg import matrix_power

        return self.sparse_lib.csc_matrix(matrix_power(D.get(), p))

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.
        The zero mode has no antiderivative in the Fourier basis; it is assigned the placeholder wavenumber i L to avoid
        a division by zero.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L
        return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p)

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        # with the unnormalized forward transform, the zero mode is the sum of all values
        weights = self.xp.zeros(self.N)
        weights[0] = self.L / self.N
        return weights

    def get_plan(self, u, forward, *args, **kwargs):
        """
        Get a callable performing the FFT.
        With FFTW from mpi4py-fft, transform objects are created once per configuration and cached in `self.plans`.
        Otherwise, the FFT function of the library with the appropriate normalization is returned.

        Args:
            u: Data that the transform is going to be applied to
            forward (bool): Whether to get the forward or backward transform

        Returns:
            callable transform
        """
        if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
            if 'axes' in kwargs.keys():
                kwargs['axes'] = tuple(kwargs['axes'])
            key = (forward, u.shape, args, *(me for me in kwargs.values()))
            if key in self.plans.keys():
                return self.plans[key]
            else:
                self.logger.debug(f'Generating FFT plan for {key=}')
                transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs)
                self.plans[key] = transform

            return self.plans[key]
        else:
            if forward:
                return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward'))
            else:
                return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward'))

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        FFT along axes. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform over

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # pass `forward` positionally before `*args`, otherwise any extra positional argument would be bound to it
        plan = self.get_plan(u, True, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs)

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse FFT, normalized by the number of points in the transformed axes.

        Args:
            u: Data you want to transform
            axes (tuple): Axes over which to transform

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # pass `forward` positionally before `*args`, otherwise any extra positional argument would be bound to it
        plan = self.get_plan(u, False, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest (most negative) wavenumber, which doesn't
        have a positive counterpart for even resolution. This means real waves of this wave number cannot be properly
        resolved and you are best advised to set this mode zero if representing real functions on even-resolution grids
        is what you're after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        Nyquist_mode = k.min()
        return self.xp.where(k == Nyquist_mode)[0][0]

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
class SpectralHelper:
    """
    This class has three functions:
        - Easily assemble matrices containing multiple equations
        - Form direct products of 1D bases to solve problems in more dimensions
        - Distribute the FFTs to facilitate concurrency.

    Attributes:
        comm (mpi4py.Intracomm): MPI communicator
        debug (bool): Perform additional checks at extra computational cost
        useGPU (bool): Whether to use GPUs
        axes (list): List of 1D bases
        components (list): List of strings of the names of components in the equations
        full_BCs (list): List of dictionaries containing all information about the boundary conditions
        BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
        BCs (sparse matrix): Matrix containing only the BCs
        fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
        BC_rhs_mask (self.xp.ndarray): Mask of the values in the right hand side that contain boundary conditions
        BC_zero_index (self.xp.ndarray): Indices of rows in the matrix that are replaced by BCs
        BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
        rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
        global_shape (tuple): Global shape of the solution as in `mpi4py-fft`, with the component axis first
        fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
        init (tuple): This is the same `init` that is used throughout the problem classes
        init_physical (tuple): Same as `init`, i.e. the local layout in physical space
        init_forward (tuple): This is the equivalent of `init` in spectral space
    """

    # Array, FFT and sparse-matrix modules shared by all instances. These are *class* attributes that the classmethods
    # `setup_GPU` / `setup_CPU` overwrite, so switching between CPU and GPU affects every instance of the class.
    xp = np
    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    dtype = mesh  # data container type used by the `u_init*` properties
    fft_backend = 'scipy'  # forwarded as `backend` to mpi4py-fft
    fft_comm_backend = 'MPI'  # forwarded as `comm_backend` to mpi4py-fft
1090 @classmethod
1091 def setup_GPU(cls):
1092 """switch to GPU modules"""
1093 import cupy as cp
1094 import cupyx.scipy.sparse as sparse_lib
1095 import cupyx.scipy.sparse.linalg as linalg
1096 import cupyx.scipy.fft as fft_lib
1097 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
1099 cls.xp = cp
1100 cls.sparse_lib = sparse_lib
1101 cls.linalg = linalg
1103 cls.fft_lib = fft_lib
1104 cls.fft_backend = 'cupyx-scipy'
1105 cls.fft_comm_backend = 'NCCL'
1107 cls.dtype = cupy_mesh
1109 @classmethod
1110 def setup_CPU(cls, useFFTW=False):
1111 """switch to CPU modules"""
1113 cls.xp = np
1114 cls.sparse_lib = scipy.sparse
1115 cls.linalg = scipy.sparse.linalg
1117 if useFFTW:
1118 from mpi4py_fft import fftw
1120 cls.fft_backend = 'fftw'
1121 cls.fft_lib = fftw
1122 else:
1123 cls.fft_backend = 'scipy'
1124 cls.fft_lib = scipy.fft
1126 cls.fft_comm_backend = 'MPI'
1127 cls.dtype = mesh
1129 def __init__(self, comm=None, useGPU=False, debug=False):
1130 """
1131 Constructor
1133 Args:
1134 comm (mpi4py.Intracomm): MPI communicator
1135 useGPU (bool): Whether to use GPUs
1136 debug (bool): Perform additional checks at extra computational cost
1137 """
1138 self.comm = comm
1139 self.debug = debug
1140 self.useGPU = useGPU
1142 if useGPU:
1143 self.setup_GPU()
1144 else:
1145 self.setup_CPU()
1147 self.axes = []
1148 self.components = []
1150 self.full_BCs = []
1151 self.BC_mat = None
1152 self.BCs = None
1154 self.fft_cache = {}
1156 self.logger = logging.getLogger(name='Spectral Discretization')
1157 if debug:
1158 self.logger.setLevel(logging.DEBUG)
1160 @property
1161 def u_init(self):
1162 """
1163 Get empty data container in physical space
1164 """
1165 return self.dtype(self.init)
1167 @property
1168 def u_init_forward(self):
1169 """
1170 Get empty data container in spectral space
1171 """
1172 return self.dtype(self.init_forward)
1174 @property
1175 def u_init_physical(self):
1176 """
1177 Get empty data container in physical space
1178 """
1179 return self.dtype(self.init_physical)
1181 @property
1182 def shape(self):
1183 """
1184 Get shape of individual solution component
1185 """
1186 return self.init[0][1:]
1188 @property
1189 def ndim(self):
1190 return len(self.axes)
1192 @property
1193 def ncomponents(self):
1194 return len(self.components)
1196 @property
1197 def V(self):
1198 """
1199 Get domain volume
1200 """
1201 return np.prod([me.L for me in self.axes])
1203 def add_axis(self, base, *args, **kwargs):
1204 """
1205 Add an axis to the domain by deciding on suitable 1D base.
1206 Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
1207 bases for possible arguments.
1209 Args:
1210 base (str): 1D spectral method
1211 """
1212 kwargs['useGPU'] = self.useGPU
1214 if base.lower() in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
1215 self.axes.append(ChebychevHelper(*args, **kwargs))
1216 elif base.lower() in ['fft', 'fourier', 'ffthelper']:
1217 self.axes.append(FFTHelper(*args, **kwargs))
1218 elif base.lower() in ['ultraspherical', 'gegenbauer']:
1219 self.axes.append(UltrasphericalHelper(*args, **kwargs))
1220 else:
1221 raise NotImplementedError(f'{base=!r} is not implemented!')
1222 self.axes[-1].xp = self.xp
1223 self.axes[-1].sparse_lib = self.sparse_lib
1225 def add_component(self, name):
1226 """
1227 Add solution component(s).
1229 Args:
1230 name (str or list of strings): Name(s) of component(s)
1231 """
1232 if type(name) in [list, tuple]:
1233 for me in name:
1234 self.add_component(me)
1235 elif type(name) in [str]:
1236 if name in self.components:
1237 raise Exception(f'{name=!r} is already added to this problem!')
1238 self.components.append(name)
1239 else:
1240 raise NotImplementedError
1242 def index(self, name):
1243 """
1244 Get the index of component `name`.
1246 Args:
1247 name (str or list of strings): Name(s) of component(s)
1249 Returns:
1250 int: Index of the component
1251 """
1252 if type(name) in [str, int]:
1253 return self.components.index(name)
1254 elif type(name) in [list, tuple]:
1255 return (self.index(me) for me in name)
1256 else:
1257 raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
1259 def get_empty_operator_matrix(self, diag=False):
1260 """
1261 Return a matrix of operators to be filled with the connections between the solution components.
1263 Args:
1264 diag (bool): Whether operator is block-diagonal
1266 Returns:
1267 list containing sparse zeros
1268 """
1269 S = len(self.components)
1270 O = self.get_Id() * 0
1271 if diag:
1272 return [O for _ in range(S)]
1273 else:
1274 return [[O for _ in range(S)] for _ in range(S)]
    def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
        """
        Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
        Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
        Typically, the last line that does not contain a BC is the best choice.
        Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

        Args:
            axis (int): Axis you want to add the BC to
            kind (str): kind of BC, e.g. Dirichlet
            line (int): Line you want the BC to go in
            scalar (bool): If True, only put the BC on the first (zero) mode of the other direction(s) instead of on
                every mode of the other direction(s)

        Returns:
            sparse matrix containing the BC, restricted to the locally held modes

        Raises:
            NotImplementedError: For more than three dimensions
        """
        sp = scipy.sparse

        base = self.axes[axis]

        # 1D matrix of zeros with the BC row of the 1D base in row `line`. It is assembled with SciPy on the host;
        # on GPU the BC row is therefore copied to the host first.
        BC = sp.eye(base.N).tolil() * 0
        if self.useGPU:
            BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
        else:
            BC[line, :] = base.get_BC(kind=kind, **kwargs)

        ndim = len(self.axes)
        if ndim == 1:
            mat = self.sparse_lib.csc_matrix(BC)
        elif ndim == 2:
            axis2 = (axis + 1) % ndim

            # scalar BCs act only on the first mode of the other axis, otherwise on all of its modes
            if scalar:
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
            else:
                _Id = self.axes[axis2].get_Id()

            Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

            mats = [
                None,
            ] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
            mats[axis2] = Id
            mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
        elif ndim == 3:
            mats = [
                None,
            ] * ndim

            # (restricted) identities in all directions except the one the BC acts in
            for ax in range(ndim):
                if ax == axis:
                    continue

                if scalar:
                    _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
                else:
                    _Id = self.axes[ax].get_Id()

                mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

            mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

            # kron only takes two operands, hence the nesting
            mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
        else:
            raise NotImplementedError(
                f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
            )
        mat = self.eliminate_zeros(mat)
        return mat
1347 def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
1348 """
1349 Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
1350 remove single BCs again, as in incompressible Navier-Stokes, for instance.
1351 Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.
1353 Args:
1354 component (str): Name of the component the BC should act on
1355 equation (str): Name of the equation for the component you want to put the BC in
1356 axis (int): Axis you want to add the BC to
1357 kind (str): kind of BC, e.g. Dirichlet
1358 v: Value of the BC
1359 line (int): Line you want the BC to go in
1360 scalar (bool): Put the BC in all space positions in the other direction
1361 """
1362 _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
1363 _BC = self.eliminate_zeros(_BC)
1364 self.BC_mat[self.index(equation)][self.index(component)] -= _BC
1366 if scalar:
1367 slices = [self.index(equation)] + [
1368 0,
1369 ] * self.ndim
1370 slices[axis + 1] = line
1371 else:
1372 slices = (
1373 [self.index(equation)]
1374 + [slice(0, self.init[0][i + 1]) for i in range(axis)]
1375 + [line]
1376 + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
1377 )
1378 N = self.axes[axis].N
1379 if (N + line) % N in self.xp.arange(N)[self.local_slice()[axis]]:
1380 self.BC_rhs_mask[(*slices,)] = False
    def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
        """
        Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
        single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
        Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

        Besides adding to `BC_mat`, the BC is recorded in `full_BCs` and the affected right hand side entries are
        flagged in `BC_rhs_mask`.

        Args:
            component (str): Name of the component the BC should act on
            equation (str): Name of the equation for the component you want to put the BC in
            axis (int): Axis you want to add the BC to
            kind (str): kind of BC, e.g. Dirichlet
            v: Value of the BC
            line (int): Line you want the BC to go in
            scalar (bool): If True, only put the BC on the first (zero) mode of the other direction(s) instead of on
                every mode of the other direction(s)
        """
        _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
        self.BC_mat[self.index(equation)][self.index(component)] += _BC
        # keep everything needed to reconstruct the BC later (`check_BCs`, `put_BCs_in_rhs*`)
        self.full_BCs += [
            {
                'component': component,
                'equation': equation,
                'axis': axis,
                'kind': kind,
                'v': v,
                'line': line,
                'scalar': scalar,
                **kwargs,
            }
        ]

        # NOTE(review): `slices[axis + 1]` assumes a non-negative `axis`; with e.g. axis=-1 it would overwrite the
        # equation index in `slices[0]`. Confirm callers only pass non-negative axes.
        if scalar:
            slices = [self.index(equation)] + [
                0,
            ] * self.ndim
            slices[axis + 1] = line
            # the zero mode of the other direction(s) is only marked on rank 0 when a communicator is used
            if self.comm:
                if self.comm.rank == 0:
                    self.BC_rhs_mask[(*slices,)] = True
            else:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            # mark the whole line, converting the global line index to a local one on the rank holding it
            slices = [self.index(equation), *self.global_slice(True)]
            N = self.axes[axis].N
            if (N + line) % N in self.get_indices(True)[axis]:
                slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
                self.BC_rhs_mask[(*slices,)] = True
1429 def setup_BCs(self):
1430 """
1431 Convert the list of lists of BCs to the boundary condition operator.
1432 Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
1433 condition. This method sets up a suitable sparse matrix to do this.
1434 """
1435 sp = self.sparse_lib
1436 self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)
1437 self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[self.BC_rhs_mask.flatten()]
1439 diags = self.xp.ones(self.BCs.shape[0])
1440 diags[self.BC_zero_index] = 0
1441 self.BC_line_zero_matrix = sp.diags(diags).tocsc()
1443 # prepare BCs in spectral space to easily add to the RHS
1444 rhs_BCs = self.put_BCs_in_rhs(self.u_init)
1445 self.rhs_BCs_hat = self.transform(rhs_BCs).view(self.xp.ndarray)
1446 del self.BC_rhs_mask
1448 def check_BCs(self, u):
1449 """
1450 Check that the solution satisfies the boundary conditions
1452 Args:
1453 u: The solution you want to check
1454 """
1455 assert self.ndim < 3
1456 for axis in range(self.ndim):
1457 BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]
1459 if len(BCs) > 0:
1460 u_hat = self.transform(u, axes=(axis - self.ndim,))
1461 for BC in BCs:
1462 kwargs = {
1463 key: value
1464 for key, value in BC.items()
1465 if key not in ['component', 'equation', 'axis', 'v', 'line', 'scalar']
1466 }
1468 if axis == 0:
1469 get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
1470 elif axis == 1:
1471 get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
1472 want = BC['v']
1473 assert self.xp.allclose(
1474 get, want
1475 ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'
1477 def put_BCs_in_matrix(self, A):
1478 """
1479 Put the boundary conditions in a matrix by replacing rows with BCs.
1480 """
1481 return self.BC_line_zero_matrix @ A + self.BCs
1483 def put_BCs_in_rhs_hat(self, rhs_hat):
1484 """
1485 Put the BCs in the right hand side in spectral space for solving.
1486 This function needs no transforms and caches a mask for faster subsequent use.
1488 Args:
1489 rhs_hat: Right hand side in spectral space
1491 Returns:
1492 rhs in spectral space with BCs
1493 """
1494 if not hasattr(self, '_rhs_hat_zero_mask'):
1495 """
1496 Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
1497 by the boundary conditions. The mask is then cached.
1498 """
1499 self._rhs_hat_zero_mask = self.newDistArray(forward_output=True).astype(bool).view(self.xp.ndarray)
1501 for axis in range(self.ndim):
1502 for bc in self.full_BCs:
1503 if axis == bc['axis']:
1504 slices = [self.index(bc['equation']), *self.global_slice(True)]
1505 N = self.axes[axis].N
1506 line = bc['line']
1507 if (N + line) % N in self.get_indices(True)[axis]:
1508 slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
1509 self._rhs_hat_zero_mask[(*slices,)] = True
1511 rhs_hat[self._rhs_hat_zero_mask] = 0
1512 return rhs_hat + self.rhs_BCs_hat
1514 def put_BCs_in_rhs(self, rhs):
1515 """
1516 Put the BCs in the right hand side for solving.
1517 This function will transform along each axis individually and add all BCs in that axis.
1518 Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.
1520 Args:
1521 rhs: Right hand side in physical space
1523 Returns:
1524 rhs in physical space with BCs
1525 """
1526 assert rhs.ndim > 1, 'rhs must not be flattened here!'
1528 ndim = self.ndim
1530 for axis in range(ndim):
1531 _rhs_hat = self.transform(rhs, axes=(axis - ndim,))
1533 for bc in self.full_BCs:
1535 if axis == bc['axis']:
1536 _slice = [self.index(bc['equation']), *self.global_slice(True)]
1538 N = self.axes[axis].N
1539 line = bc['line']
1540 if (N + line) % N in self.get_indices(True)[axis]:
1541 _slice[axis + 1] = (N + line) % N - self.local_slice()[axis].start
1542 _rhs_hat[(*_slice,)] = bc['v']
1544 rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))
1546 return rhs
1548 def add_equation_lhs(self, A, equation, relations):
1549 """
1550 Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
1551 that you will convert to an operator later.
1553 Example:
1554 Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
1555 preconditioning:
1557 .. code-block:: python
1558 helper = SpectralHelper()
1560 helper.add_axis(base='chebychev', N=8)
1561 helper.add_component(['u', 'ux'])
1562 helper.setup_fft()
1564 I = helper.get_Id()
1565 Dx = helper.get_differentiation_matrix(axes=(0,))
1566 T2U = helper.get_basis_change_matrix('T2U')
1568 L_lhs = {
1569 'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
1570 'u': {'ux': -(T2U @ Dx)},
1571 }
1573 operator = helper.get_empty_operator_matrix()
1574 for line, equation in L_lhs.items():
1575 helper.add_equation_lhs(operator, line, equation)
1577 L = helper.convert_operator_matrix_to_operator(operator)
1579 Args:
1580 A (list of lists of sparse matrices): The operator to be
1581 equation (str): The equation of the component you want this in
1582 relations: (dict): Relations between quantities
1583 """
1584 for k, v in relations.items():
1585 A[self.index(equation)][self.index(k)] = v
1587 def eliminate_zeros(self, A):
1588 """
1589 Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat.
1590 Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`.
1591 Therefore, this function copies the matrix to host, eliminates the zeros there and then copies back to GPU.
1593 Args:
1594 A: sparse matrix to be pruned
1596 Returns:
1597 CSC sparse matrix
1598 """
1599 if self.useGPU:
1600 A = A.get()
1601 A = A.tocsc()
1602 A.eliminate_zeros()
1603 if self.useGPU:
1604 A = self.sparse_lib.csc_matrix(A)
1605 return A
1607 def convert_operator_matrix_to_operator(self, M):
1608 """
1609 Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
1610 See documentation of `SpectralHelper.add_equation_lhs` for an example.
1612 Args:
1613 M (list of lists of sparse matrices): The operator to be
1615 Returns:
1616 sparse linear operator
1617 """
1618 if len(self.components) == 1:
1619 op = M[0][0]
1620 else:
1621 op = self.sparse_lib.bmat(M, format='csc')
1623 op = self.eliminate_zeros(op)
1624 return op
1626 def get_wavenumbers(self):
1627 """
1628 Get grid in spectral space
1629 """
1630 grids = [self.axes[i].get_wavenumbers()[self.local_slice(True)[i]] for i in range(len(self.axes))]
1631 return self.xp.meshgrid(*grids, indexing='ij')
1633 def get_grid(self, forward_output=False):
1634 """
1635 Get grid in physical space
1636 """
1637 grids = [self.axes[i].get_1dgrid()[self.local_slice(forward_output)[i]] for i in range(len(self.axes))]
1638 return self.xp.meshgrid(*grids, indexing='ij')
1640 def get_indices(self, forward_output=True):
1641 return [self.xp.arange(self.axes[i].N)[self.local_slice(forward_output)[i]] for i in range(len(self.axes))]
    @cache
    def get_pfft(self, axes=None, padding=None, grid=None):
        """
        Get a parallel transform object from mpi4py-fft.

        Results are memoised per instance by the `cache` decorator, which uses the arguments as dictionary key; all
        arguments must therefore be hashable (e.g. tuples rather than lists).

        Args:
            axes (tuple): Axes to transform with the respective 1D bases. The remaining axes get a no-op transform.
                Defaults to all axes.
            padding (sequence of float): Padding factor per axis for dealiasing. Defaults to 1 in every axis.
            grid: Forwarded unchanged to `mpi4py_fft.PFFT`

        Returns:
            `mpi4py_fft.PFFT` object, or None in 1D or when no communicator is used
        """
        if self.ndim == 1 or self.comm is None:
            return None
        from mpi4py_fft import PFFT, newDistArray

        axes = tuple(i for i in range(self.ndim)) if axes is None else axes
        padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

        def no_transform(u, *args, **kwargs):
            return u

        # every axis gets an identity transform; the requested axes get the transforms of their 1D bases
        transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
        for i in axes:
            transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

        # "transform" all axes to ensure consistent shapes.
        # Transform non-distributable axes last to ensure they are aligned
        # (They are listed first in `_axes`. Presumably mpi4py-fft applies the forward transforms starting from the
        # last listed axis, so listing them first means transforming them last -- TODO confirm.)
        _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
        _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
            [axis for axis in _axes if self.axes[axis].distributable]
            + [axis for axis in range(self.ndim) if axis not in _axes]
        )

        pfft = PFFT(
            comm=self.comm,
            shape=self.global_shape[1:],
            axes=_axes,  # TODO: control the order of the transforms better
            dtype='D',
            collapse=False,
            backend=self.fft_backend,
            comm_backend=self.fft_comm_backend,
            padding=padding,
            transforms=transforms,
            grid=grid,
        )

        # do a transform to do the planning
        _u = newDistArray(pfft, forward_output=False)
        pfft.backward(pfft.forward(_u))
        return pfft
    def get_fft(self, axes=None, direction='object', padding=None, shape=None):
        """
        When using MPI, we use `PFFT` objects generated by mpi4py-fft.
        Results are cached in `self.fft_cache`, keyed on all arguments.

        Without a communicator, 'forward'/'backward' return `self.xp.fft.fftn`/`self.xp.fft.ifftn` and 'object'
        returns None. NOTE(review): in that case `axes` is only part of the cache key and is not bound into the returned
        function, so the caller has to pass axes to it explicitly if needed.

        Args:
            axes (tuple): Axes you want to transform over
            direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
            padding (tuple): Padding for dealiasing (only supported with a communicator)
            shape (tuple): Shape of the transform

        Returns:
            transform
        """
        axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
        shape = self.global_shape[1:] if shape is None else shape
        padding = (
            [
                1,
            ]
            * self.ndim
            if padding is None
            else padding
        )
        key = (axes, direction, tuple(padding), tuple(shape))

        if key not in self.fft_cache.keys():
            if self.comm is None:
                assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

                if direction == 'forward':
                    self.fft_cache[key] = self.xp.fft.fftn
                elif direction == 'backward':
                    self.fft_cache[key] = self.xp.fft.ifftn
                elif direction == 'object':
                    self.fft_cache[key] = None
            else:
                if direction == 'object':
                    from mpi4py_fft import PFFT

                    _fft = PFFT(
                        comm=self.comm,
                        shape=shape,
                        axes=sorted(axes),
                        dtype='D',
                        collapse=False,
                        backend=self.fft_backend,
                        comm_backend=self.fft_comm_backend,
                        padding=padding,
                    )
                else:
                    # forward/backward reuse (and cache) the PFFT object for the same arguments
                    _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

                if direction == 'forward':
                    self.fft_cache[key] = _fft.forward
                elif direction == 'backward':
                    self.fft_cache[key] = _fft.backward
                elif direction == 'object':
                    self.fft_cache[key] = _fft

        return self.fft_cache[key]
1746 def local_slice(self, forward_output=True):
1747 if self.fft_obj:
1748 return self.get_pfft().local_slice(forward_output=forward_output)
1749 else:
1750 return [slice(0, me.N) for me in self.axes]
1752 def global_slice(self, forward_output=True):
1753 if self.fft_obj:
1754 return [slice(0, me) for me in self.fft_obj.global_shape(forward_output=forward_output)]
1755 else:
1756 return self.local_slice(forward_output=forward_output)
1758 def setup_fft(self, real_spectral_coefficients=False):
1759 """
1760 This function must be called after all axes have been setup in order to prepare the local shapes of the data.
1761 This must also be called before setting up any BCs.
1763 Args:
1764 real_spectral_coefficients (bool): Allow only real coefficients in spectral space
1765 """
1766 if len(self.components) == 0:
1767 self.add_component('u')
1769 self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
1771 axes = tuple(i for i in range(len(self.axes)))
1772 self.fft_obj = self.get_pfft(axes=axes)
1774 self.init = (
1775 np.empty(shape=self.global_shape)[
1776 (
1777 ...,
1778 *self.local_slice(False),
1779 )
1780 ].shape,
1781 self.comm,
1782 np.dtype('float'),
1783 )
1784 self.init_physical = (
1785 np.empty(shape=self.global_shape)[
1786 (
1787 ...,
1788 *self.local_slice(False),
1789 )
1790 ].shape,
1791 self.comm,
1792 np.dtype('float'),
1793 )
1794 self.init_forward = (
1795 np.empty(shape=self.global_shape)[
1796 (
1797 ...,
1798 *self.local_slice(True),
1799 )
1800 ].shape,
1801 self.comm,
1802 np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
1803 )
1805 self.BC_mat = self.get_empty_operator_matrix()
1806 self.BC_rhs_mask = self.newDistArray().astype(bool)
    def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
        """
        Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
        takes care of all the solution components in the tensor.

        NOTE(review): without a communicator this returns `self.xp.zeros(self.init[0], dtype=self.init[2])`, i.e. an
        array with the physical-space shape and dtype, regardless of `forward_output`, `val`, `rank` and `view`.
        Confirm callers only rely on the shape (e.g. for boolean masks).

        Args:
            pfft: mpi4py-fft transform object; defaults to `self.get_pfft()`
            forward_output (bool): Spectral-space (True) or physical-space (False) layout
            val: Initial value, forwarded to the DistArray
            rank (int): Number of leading component axes of length `self.ncomponents`
            view (bool): Return the plain array view `z.v` instead of the DistArray

        Returns:
            array with all solution components
        """
        if self.comm is None:
            return self.xp.zeros(self.init[0], dtype=self.init[2])
        from mpi4py_fft.distarray import DistArray

        pfft = pfft if pfft else self.get_pfft()
        if pfft is None:
            # e.g. in 1D no parallel transform is used
            if forward_output:
                return self.u_init_forward
            else:
                return self.u_init

        global_shape = pfft.global_shape(forward_output)
        p0 = pfft.pencil[forward_output]
        if forward_output is True:
            dtype = pfft.forward.output_array.dtype
        else:
            dtype = pfft.forward.input_array.dtype
        global_shape = (self.ncomponents,) * rank + global_shape

        # GPU transforms need the CuPy flavour of the distributed array
        if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
            from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
        else:
            darraycls = DistArray

        z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
        return z.v if view else z
    def infer_alignment(self, u, forward_output, padding=None, **kwargs):
        """
        Find the spatial axes along which the local array `u` holds the full global extent.

        The local shape of `u` is compared against the global shapes of distributed arrays of candidate transform
        objects. Without padding, the transform for `kwargs` is used. With padding, combinations where the padding is
        applied to only some of the axes are tried in turn, until one yields at least one aligned axis.

        Args:
            u: Array with the component axis first
            forward_output (bool): Compare against the spectral-space (True) or physical-space (False) layout
            padding (sequence of float): Padding used for `u`, if any
            **kwargs: Forwarded to `get_pfft`

        Returns:
            list of int: Aligned spatial axes (0-based, not counting the component axis); `[0]` without a communicator

        Raises:
            NotImplementedError: If padding is given in other than two or three dimensions
        """
        if self.comm is None:
            return [0]

        def _alignment(pfft):
            _arr = self.newDistArray(pfft, forward_output=forward_output)
            # an axis is aligned if the local extent of `u` equals the global extent
            _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
            return _aligned_axes

        if padding is None:
            pfft = self.get_pfft(**kwargs)
            aligned_axes = _alignment(pfft)
        else:
            if self.ndim == 2:
                padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
            elif self.ndim == 3:
                padding_options = [
                    (1.0, 1.0, padding[2]),
                    (1.0, padding[1], 1.0),
                    (padding[0], 1.0, 1.0),
                    (1.0, padding[1], padding[2]),
                    (padding[0], 1.0, padding[2]),
                    (padding[0], padding[1], 1.0),
                    padding,
                    (1.0, 1.0, 1.0),
                ]
            else:
                raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
            for _padding in padding_options:
                pfft = self.get_pfft(padding=_padding, **kwargs)
                aligned_axes = _alignment(pfft)
                if len(aligned_axes) > 0:
                    self.logger.debug(
                        f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                    )
                    break

        assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
        return aligned_axes
1880 def redistribute(self, u, axis, forward_output, **kwargs):
1881 if self.comm is None:
1882 return u
1884 pfft = self.get_pfft(**kwargs)
1885 _arr = self.newDistArray(pfft, forward_output=forward_output)
1887 if 'Dist' in type(u).__name__ and False:
1888 try:
1889 u.redistribute(out=_arr)
1890 return _arr
1891 except AssertionError:
1892 pass
1894 u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
1895 for alignment in u_alignment:
1896 _arr = _arr.redistribute(alignment)
1897 if _arr.shape == u.shape:
1898 _arr[...] = u
1899 return _arr.redistribute(axis)
1901 raise Exception(
1902 f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
1903 )
1905 def transform(self, u, *args, axes=None, padding=None, **kwargs):
1906 pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
1908 if pfft is None:
1909 axes = axes if axes else tuple(i for i in range(self.ndim))
1910 u_hat = u.copy()
1911 for i in axes:
1912 _axis = 1 + i if i >= 0 else i
1913 u_hat = self.axes[i].transform(u_hat, axes=(_axis,))
1914 return u_hat
1916 _in = self.newDistArray(pfft, forward_output=False, rank=1)
1917 _out = self.newDistArray(pfft, forward_output=True, rank=1)
1919 if _in.shape == u.shape:
1920 _in[...] = u
1921 else:
1922 _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs)
1924 for i in range(self.ncomponents):
1925 pfft.forward(_in[i], _out[i], normalize=False)
1927 if padding is not None:
1928 _out /= np.prod(padding)
1929 return _out
1931 def itransform(self, u, *args, axes=None, padding=None, **kwargs):
1932 if padding is not None:
1933 assert all(
1934 (self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim)
1935 ), 'Cannot do this padding with this resolution. Resulting resolution must be integer'
1937 pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
1938 if pfft is None:
1939 axes = axes if axes else tuple(i for i in range(self.ndim))
1940 u_hat = u.copy()
1941 for i in axes:
1942 _axis = 1 + i if i >= 0 else i
1943 u_hat = self.axes[i].itransform(u_hat, axes=(_axis,))
1944 return u_hat
1946 _in = self.newDistArray(pfft, forward_output=True, rank=1)
1947 _out = self.newDistArray(pfft, forward_output=False, rank=1)
1949 if _in.shape == u.shape:
1950 _in[...] = u
1951 else:
1952 _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs)
1954 for i in range(self.ncomponents):
1955 pfft.backward(_in[i], _out[i], normalize=True)
1957 if padding is not None:
1958 _out *= np.prod(padding)
1959 return _out
1961 def get_local_slice_of_1D_matrix(self, M, axis):
1962 """
1963 Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
1964 which you can sort out via the `SpectralHelper.local_slice()` attribute. When constructing a 1D matrix, you can
1965 use this method to get the part corresponding to the modes carried by this rank.
1967 Args:
1968 M (sparse matrix): Global 1D matrix you want to get the local version of
1969 axis (int): Direction in which you want the local version. You will get the global matrix in other directions.
1971 Returns:
1972 sparse local matrix
1973 """
1974 return M.tocsc()[self.local_slice(True)[axis], self.local_slice(True)[axis]]
1976 def expand_matrix_ND(self, matrix, aligned):
1977 sp = self.sparse_lib
1978 axes = np.delete(np.arange(self.ndim), aligned)
1979 ndim = len(axes) + 1
1981 if ndim == 1:
1982 mat = matrix
1983 elif ndim == 2:
1984 axis = axes[0]
1985 I1D = sp.eye(self.axes[axis].N)
1987 mats = [None] * ndim
1988 mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
1989 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
1991 mat = sp.kron(*mats)
1992 elif ndim == 3:
1994 mats = [None] * ndim
1995 mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
1996 for axis in axes:
1997 I1D = sp.eye(self.axes[axis].N)
1998 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
2000 mat = sp.kron(mats[0], sp.kron(*mats[1:]))
2002 else:
2003 raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')
2005 mat = self.eliminate_zeros(mat)
2006 return mat
2008 def get_filter_matrix(self, axis, **kwargs):
2009 """
2010 Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
2011 admissible.
2013 Returns:
2014 sparse bandpass matrix
2015 """
2016 if self.ndim == 1:
2017 return self.axes[0].get_filter_matrix(**kwargs)
2019 mats = [base.get_Id() for base in self.axes]
2020 mats[axis] = self.axes[axis].get_filter_matrix(**kwargs)
2021 return self.sparse_lib.kron(*mats)
2023 def get_differentiation_matrix(self, axes, **kwargs):
2024 """
2025 Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation.
2027 Args:
2028 axes (tuple): Axes along which to differentiate.
2030 Returns:
2031 sparse differentiation matrix
2032 """
2033 D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0])
2034 for axis in axes[1:]:
2035 _D = self.axes[axis].get_differentiation_matrix(**kwargs)
2036 D = D @ self.expand_matrix_ND(_D, axis)
2038 self.logger.debug(f'Set up differentiation matrix along axes {axes} with kwargs {kwargs}')
2039 return D
2041 def get_integration_matrix(self, axes):
2042 """
2043 Get integration matrix to integrate along specified axis.
2045 Args:
2046 axes (tuple): Axes along which to integrate over.
2048 Returns:
2049 sparse integration matrix
2050 """
2051 S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0])
2052 for axis in axes[1:]:
2053 _S = self.axes[axis].get_integration_matrix()
2054 S = S @ self.expand_matrix_ND(_S, axis)
2056 return S
2058 def get_Id(self):
2059 """
2060 Get identity matrix
2062 Returns:
2063 sparse identity matrix
2064 """
2065 I = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
2066 for axis in range(1, self.ndim):
2067 _I = self.axes[axis].get_Id()
2068 I = I @ self.expand_matrix_ND(_I, axis)
2069 return I
2071 def get_Dirichlet_recombination_matrix(self, axis=-1):
2072 """
2073 Get Dirichlet recombination matrix along axis. Not that it only makes sense in directions discretized with variations of Chebychev bases.
2075 Args:
2076 axis (int): Axis you discretized with Chebychev
2078 Returns:
2079 sparse matrix
2080 """
2081 C1D = self.axes[axis].get_Dirichlet_recombination_matrix()
2082 return self.expand_matrix_ND(C1D, axis)
2084 def get_basis_change_matrix(self, axes=None, **kwargs):
2085 """
2086 Some spectral bases do a change between bases while differentiating. This method returns matrices that changes the basis to whatever you want.
2087 Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.
2089 Args:
2090 axes (tuple): Axes along which to change basis.
2092 Returns:
2093 sparse basis change matrix
2094 """
2095 axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
2097 C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0])
2098 for axis in axes[1:]:
2099 _C = self.axes[axis].get_basis_change_matrix(**kwargs)
2100 C = C @ self.expand_matrix_ND(_C, axis)
2102 self.logger.debug(f'Set up basis change matrix along axes {axes} with kwargs {kwargs}')
2103 return C