Coverage for pySDC / helpers / spectral_helper.py: 83%
864 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 07:06 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 07:06 +0000
1import numpy as np
2import scipy
3from pySDC.implementations.datatype_classes.mesh import mesh
4from scipy.special import factorial
5from functools import partial, wraps
6import logging
def cache(func):
    """
    Decorator for caching the return values of instance methods.

    This is similar to `functools.cache`, but the cache lives on the instance itself (in an attribute called
    ``_<method name>_cache``) instead of in a global dictionary keyed by ``self``. This avoids keeping instances alive
    for as long as the cache exists (see https://docs.astral.sh/ruff/rules/cached-instance-method/).

    The decorator is meant for methods only: the wrapper treats its first positional argument as the object that owns
    the cache. All other positional arguments and all keyword-argument values are used as the cache key and must
    therefore be hashable. Cached objects are returned as-is, so mutating a returned value also mutates the cache.

    Example:

    .. code-block:: python

        class Counter:
            def __init__(self):
                self.num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        c = Counter()
        c.increment(0)  # returns 1, c.num_calls = 1
        c.increment(1)  # returns 2, c.num_calls = 2
        c.increment(0)  # returns 1, c.num_calls = 2

    Args:
        func (function): The method you want to cache the return values of

    Returns:
        function: Wrapped method that stores and looks up results in a per-instance cache
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # The per-instance cache is created lazily on first use, so classes need no changes in their constructor.
        if not hasattr(self, attr_cache):
            setattr(self, attr_cache, {})

        cache = getattr(self, attr_cache)

        key = (args, frozenset(kwargs.items()))
        if key in cache:
            return cache[key]
        result = func(self, *args, **kwargs)
        cache[key] = result
        return result

    return wrapper
class vkFFT(object):
    """
    pyVkFFT FFT backend with a `scipy.fft`-like interface.
    The special feature of vkFFT is fast DCT on GPU with cached plans.

    Only the subset of the scipy interface used in this module is supported:
    - Transforms are always out-of-place (`overwrite_x` must be False).
    - Padding via `s` is not implemented; the argument is ignored.
    - Only one normalization per transform is accepted (see the individual methods).
    """

    # Plans shared by all users of this class, keyed by the string built in `_plan_key`
    cached_plans = {}

    @staticmethod
    def is_complex(x):
        """Return whether the dtype of array `x` is complex."""
        return 'complex' in str(x.dtype)

    @staticmethod
    def _get_axes(x, axes):
        """Return `axes` as a tuple, defaulting to all axes of `x` as scipy.fft does."""
        return tuple(range(x.ndim)) if axes is None else tuple(axes)

    @staticmethod
    def _plan_key(transform_type, shape, dtype, axes, norm):
        """Build the key under which a plan is stored in `vkFFT.cached_plans`."""
        return f'{transform_type=}, {shape=}, {dtype=}, {axes=}, {norm=}'

    @staticmethod
    def get_plan(transform_type, shape, dtype, axes, norm):
        """
        Get a VkFFT plan, generating and caching it on first use.

        Args:
            transform_type (str): 'dct' selects a type-2 DCT plan, anything else (e.g. 'fft') a plain FFT plan
            shape (tuple): Shape of the data to transform
            dtype: Data type of the array the plan operates on
            axes (tuple): Axes to transform along
            norm (str): Normalization, only 'backward' is supported

        Returns:
            pyvkfft.cuda.VkFFTApp: The plan
        """
        assert norm == 'backward'

        key = vkFFT._plan_key(transform_type, shape, dtype, axes, norm)

        if key not in vkFFT.cached_plans:
            # Import lazily so pyvkfft is only needed when a new plan actually has to be generated.
            from pyvkfft.cuda import VkFFTApp

            kwargs = {}
            if transform_type == 'dct':
                kwargs['dct'] = 2

            vkFFT.cached_plans[key] = VkFFTApp(shape, dtype, len(axes), axes=axes, norm=norm, **kwargs)

            logger = logging.getLogger(name='VkFFT')
            logger.debug(f'Cached plan for VkFFT: {key}')
        return vkFFT.cached_plans[key]

    @staticmethod
    def fftn(x, s=None, axes=None, norm='backward', overwrite_x=False):
        """
        Forward FFT along `axes` (default: all axes). Returns a new complex array; `x` is left untouched.
        """
        assert not overwrite_x  # for consistent interface with scipy
        assert norm == 'backward'  # for consistent interface with scipy
        axes = vkFFT._get_axes(x, axes)

        # Promoting to complex also creates the private copy that the plan transforms in place. The plan is generated
        # for the dtype of the array it actually operates on.
        _x = x + 0j
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=_x.shape,
            dtype=_x.dtype,
            axes=axes,
            norm=norm,
        )
        plan.fft(_x)
        return _x

    @staticmethod
    def ifftn(x, s=None, axes=None, norm='forward', overwrite_x=False):
        """
        Unnormalized inverse FFT along `axes` (default: all axes), matching scipy's `norm='forward'` convention.
        Returns a new complex array; `x` is left untouched.
        """
        assert norm == 'forward'
        assert not overwrite_x  # for consistent interface with scipy
        axes = vkFFT._get_axes(x, axes)

        _x = x + 0j  # promote to complex, which also copies
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=_x.shape,
            dtype=_x.dtype,
            axes=axes,
            norm='backward',
        )
        plan.ifft(_x)
        # A 'backward' plan scales the inverse by 1/n, with n the total number of points in the transformed axes
        # (product, not sum, of their lengths). Undo that to get scipy's unnormalized norm='forward' inverse.
        return _x * int(np.prod([x.shape[i] for i in axes]))

    @staticmethod
    def dctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """
        Type-2 DCT along `axes` (default: all axes). Complex data is transformed as separate real and imaginary parts.
        Returns a new array; `x` is left untouched.
        """
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy
        norm = 'backward' if norm is None else norm  # scipy treats None as 'backward'
        axes = vkFFT._get_axes(x, axes)

        is_complex = vkFFT.is_complex(x)

        dtype = x.dtype if not is_complex else x.real.dtype

        plan = vkFFT.get_plan(
            transform_type='dct',
            shape=x.shape,
            dtype=dtype,
            axes=axes,
            norm=norm,
        )

        if is_complex:
            x_real = x.real.copy()
            x_imag = x.imag.copy()

            plan.fft(x_real)
            plan.fft(x_imag)

            return x_real + 1j * x_imag
        else:
            # The plan works in place, so transform a copy to leave the caller's data untouched.
            _x = x.copy()
            plan.fft(_x)
            return _x

    @staticmethod
    def idctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """
        Inverse of `dctn` along `axes` (default: all axes). Returns a new array; `x` is left untouched.
        """
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy
        norm = 'backward' if norm is None else norm  # scipy treats None as 'backward'
        axes = vkFFT._get_axes(x, axes)

        is_complex = vkFFT.is_complex(x)
        dtype = x.dtype if not is_complex else x.real.dtype

        plan = vkFFT.get_plan(
            transform_type='dct',
            shape=x.shape,
            dtype=dtype,
            axes=axes,
            norm=norm,
        )

        if is_complex:
            x_real = x.real.copy()
            x_imag = x.imag.copy()

            plan.ifft(x_real)
            plan.ifft(x_imag)

            return x_real + 1j * x_imag
        else:
            # The plan works in place, so transform a copy to leave the caller's data untouched.
            _x = x.copy()
            plan.ifft(_x)
            return _x
class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs. Note that `setup_GPU` and `setup_CPU` replace these modules on the class, not on the instance.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
        plans (dict): Storage for FFT plans, filled by subclasses
        logger (logging.Logger): Logger named after the concrete class
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary. Must be given, since the domain length is computed from it.
            x1 (float): Coordinate of right boundary. Must be given, since the domain length is computed from it.
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both `useGPU` and `useFFTW` are requested
        """
        # Reject the invalid combination before any (class-wide) modules are switched.
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
            self.logger.debug('Set up for GPU')
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """Switch the class-level modules and datatype to their GPU counterparts (CuPy and vkFFT)."""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = vkFFT
        # Counterpart of `cls.dtype = mesh` in `setup_CPU`
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """Switch the class-level modules to their CPU counterparts (numpy/scipy, optionally FFTW)."""

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretizations change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement. Please refer to the implementations of this method in the
                individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter that zeros all modes with ``|k| < kmin`` or ``|k| >= kmax``.

        Args:
            kmin (int): Lower limit of the bandpass filter (inclusive)
            kmax (int): Upper limit of the bandpass filter (exclusive). Defaults to the largest ``|k|``, so by default
                the mode(s) with the largest absolute wavenumber are removed as well.

        Returns:
            sparse matrix in CSC format (on the GPU when running on GPUs)
        """
        k = abs(self.get_wavenumbers())
        if self.useGPU:
            # The filter is assembled with scipy on the CPU, so the wavenumbers have to live there as well.
            k = k.get()

        kmax = k.max() if kmax is None else kmax

        mask = np.logical_or(k >= kmax, k < kmin)

        Id = self.get_Id()
        if self.useGPU:
            Id = Id.get()
        F = Id.tolil()
        F[:, mask] = 0
        F = F.tocsc()

        # Move the result back to the device so it matches the other operators of this helper.
        return self.sparse_lib.csc_matrix(F) if self.useGPU else F

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError
class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    The physical domain [x0, x1] is the image of the reference interval [-1, 1] under the affine map
    x = lin_trf_fac * y + lin_trf_off, with lin_trf_fac = (x1 - x0) / 2 and lin_trf_off = (x1 + x0) / 2.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary

        Note that `get_Dirichlet_BC_row` and `get_Neumann_BC_row` only accept the positions -1, 0, 1 (resp. -1, 1)
        of the reference interval, independent of `x0` and `x1`.
        """
        # Affine map from the reference interval [-1, 1] to [x0, x1]: x = lin_trf_fac * y + lin_trf_off
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``, e.g. ``T2U`` for the conversion matrix from T to U.
        Conversions only available in the opposite direction are obtained by inverting the forward matrix.
        Once generated, matrices are cached, so feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional, defaults to the resolution)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its inverse is known
        '''
        N = N if N else self.N
        sp = self.sparse_lib

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.eye(N, k=2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.eye(N, k=2)
            elif name[0] == name[-1]:
                # Conversion to the same basis; honour the requested size `N`
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as err:
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError from err

            if self.sparse_lib == scipy.sparse:
                mat = scipy.sparse.linalg.inv(fwd.tocsc())
            else:
                # Sparse inversion is done on the CPU and the result moved back to the GPU.
                mat = self.sparse_lib.csc_matrix(scipy.sparse.linalg.inv(fwd.tocsc().get()))

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration in Chebychev-T basis. Row 0 holds the integration constant, which is chosen for a
        lower bound at the centre of the reference interval (x=0 on the default domain [-1, 1]). The whole matrix is
        scaled by `lin_trf_fac` to account for the length of the physical domain.

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: For any other lower bound
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        N = self.N

        # Convert T to U and integrate U_n into T_{n+1} / (n + 1)
        S = self.sparse_lib.diags(1 / (xp.arange(N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        S = S.tocsc()

        # Row 0 is empty after the product above. Set the integration constant, which only has contributions from the
        # odd modes. The denominator is [1, 1, 2, 3, ...].
        n = xp.arange(N)
        S[0, 1::2] = (
            (n / (2 * (n + 1)))[1::2]
            * (-1) ** (xp.arange(N // 2))
            / xp.concatenate((xp.ones(1), xp.arange(N // 2 - 1) + 1))
        )

        # Integration in physical space picks up the Jacobian of the affine map for every entry, not just row 0.
        return S * self.lin_trf_fac

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        n = self.xp.arange(self.N, dtype=float)

        # Integral of T_n over [-1, 1]: 2 / (1 - n^2) for even n, zero for odd n
        weights = (-1) ** n + 1
        weights[2:] /= 1 - (n**2)[2:]

        # Scale from the reference interval to the physical domain
        weights /= 2 / self.L
        return weights

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        xp = self.xp
        j = xp.arange(self.N)[None, :]
        k = xp.arange(self.N)[:, None]

        # D[k, j] = 2 j for k < j with j - k odd, zero otherwise; built vectorized instead of with a Python loop.
        D = xp.where(k < j, 2 * j * ((j - k) % 2), 0).astype(float)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes, normalized such that the result are Chebychev coefficients. `kwargs` will be passed on to the
        FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Passed to the FFT library as `s`, e.g. for padding

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            # Apply the normalization along this axis; modes beyond N get the normalization of the last mode.
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes, undoing the normalization applied in `transform` along every axis.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Passed to the FFT library as `s`, e.g. for padding

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # Work on a private copy. Each axis builds on the result of the previous one, so that every axis is
        # de-normalized, mirroring the cumulative normalization in `transform`.
        _u = u.copy()
        for axis in axes:

            if self.N != _u.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                N = self.N
                src = _u
                _u = self.xp.zeros_like(src)

                # copy first half
                su = [slice(None)] * u.ndim
                su[axis] = slice(0, N // 2 + 1)
                _u[tuple(su)] = src[tuple(su)]

                # copy second half
                su = [slice(None)] * u.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                _u[tuple(s_u)] = src[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * u.ndim
                    su[axis] = N // 2
                    _u[tuple(su)] *= 2

            # generate norm
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N

            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` and `'neumann'` boundary conditions, which is the
        coordinate at which to set the BC.

        Args:
            kind ('integral', 'dirichlet' or 'neumann'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        elif kind.lower() == 'neumann':
            return self.get_Neumann_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the reference interval [-1, 1] (without scaling to
        the physical domain).

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x in the reference interval.

        Args:
            x (float): Position of the boundary condition, one of -1, 0, 1

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_n(0) = cos(n pi / 2): 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Neumann_BC_row(self, x):
        """
        Get a row for generating Neumann BCs at x with T polynomials, i.e. the derivatives of the T polynomials at x in
        the reference interval. No scaling with `lin_trf_fac` is applied.

        Args:
            x (float): Position of the boundary condition, one of -1, 1

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N, dtype='D')
        nn = n**2
        if x == -1:
            me = nn
            me[1:] *= (-1) ** n[:-1]
            return me
        elif x == 1:
            return nn
        else:
            raise NotImplementedError(f'Don\'t know how to generate Neumann BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib

        return sp.eye(N) - sp.eye(N, k=2)
class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped
    from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Get the matrix for the `p`-th derivative, mapping Chebychev-T coefficients to derivative base `p`.
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # Assembled with scipy because of the item assignment on a LIL matrix
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix in CSC format (for `p_out > p_in` the format of the matrix product)
        """
        if p_in == p_out:
            # Nothing to convert; skip the (CPU) inversion of an identity matrix.
            return self.sparse_lib.csc_matrix(self.sparse_lib.eye(self.N))

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        The constant is sum_{n>=1} (-1)^(n+1) u_hat_n along `axis`, which makes the antiderivative vanish at -1.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp
        n_modes = u_hat.shape[axis]

        # All modes except the zeroth along `axis`; every other axis is kept in full.
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n_modes)

        # Signs +1, -1, +1, ... for modes 1, 2, 3, ..., shaped to broadcast along `axis`
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = n_modes - 1
        signs = ((-1) ** xp.arange(n_modes - 1)).reshape(sign_shape)

        return xp.sum(u_hat[tuple(slices)] * signs, axis=axis)
class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems on [x0, x1). Transforms are done by FFT and differentiation is diagonal in
    spectral space. This basis is marked as distributable.
    """

    distributable = True

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use N equally spaced points that include the left boundary but exclude the right one. Due to periodicity,
        the right boundary coincides with the left one.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive: it follows `fftfreq`, i.e. non-negative wavenumbers first,
        followed by the negative ones.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()

        if self.useGPU:
            if p == 1:
                return self.sparse_lib.diags(1j * k)
            # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the
            # time of writing. This also yields the identity for p=0.
            from scipy.sparse.linalg import matrix_power

            D = self.sparse_lib.diags(1j * k).get()
            return self.sparse_lib.csc_matrix(matrix_power(D, p))
        else:
            return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        # Replace the zero wavenumber to avoid division by zero
        k[0] = 1j * self.L
        return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p)

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        weights = self.xp.zeros(self.N)
        weights[0] = self.L / self.N
        return weights

    def get_plan(self, u, forward, *args, **kwargs):
        """
        Get a callable that performs the (inverse) FFT.

        With FFTW, plans are generated once per combination of direction, shape and arguments and stored in
        `self.plans`. Otherwise, `fftn`/`ifftn` of the FFT library with the default normalization are returned.

        Args:
            u: Data to be transformed
            forward (bool): Whether to get the forward or the inverse transform

        Returns:
            callable transform
        """
        if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
            if 'axes' in kwargs.keys():
                kwargs['axes'] = tuple(kwargs['axes'])
            key = (forward, u.shape, args, *(me for me in kwargs.values()))
            if key in self.plans.keys():
                return self.plans[key]
            else:
                self.logger.debug(f'Generating FFT plan for {key=}')
                transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs)
                self.plans[key] = transform

            return self.plans[key]
        else:
            if forward:
                return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward'))
            else:
                return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward'))

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        FFT along axes. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform over
            shape (tuple): Passed to the FFT library as `s`, e.g. for padding

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally so extra positional `args` cannot collide with it.
        plan = self.get_plan(u, True, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs)

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse FFT, normalized by the number of points in the transformed axes of `u`.

        Args:
            u: Data you want to transform
            axes (tuple): Axes over which to transform
            shape (tuple): Passed to the FFT library as `s`, e.g. for padding

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally so extra positional `args` cannot collide with it.
        plan = self.get_plan(u, False, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        # argmin returns the first occurrence of the smallest wavenumber
        return int(self.xp.argmin(k))

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me
1054class SpectralHelper:
1055 """
1056 This class has three functions:
1057 - Easily assemble matrices containing multiple equations
1058 - Direct product of 1D bases to solve problems in more dimensions
1059 - Distribute the FFTs to facilitate concurrency.
1061 Attributes:
1062 comm (mpi4py.Intracomm): MPI communicator
1063 debug (bool): Perform additional checks at extra computational cost
1064 useGPU (bool): Whether to use GPUs
1065 axes (list): List of 1D bases
1066 components (list): List of strings of the names of components in the equations
1067 full_BCs (list): List of Dictionaries containing all information about the boundary conditions
1068 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from
1069 BCs (sparse matrix): Matrix containing only the BCs
1070 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on
1071 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side
1072 BC_zero_index (self.xp.ndarray): Indices of rows in the matrix that are replaced by BCs
1073 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs`
1074 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space
1075 global_shape (tuple): Global shape of the solution as in `mpi4py-fft`
1076 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft`
1077 init (tuple): This is the same `init` that is used throughout the problem classes
1078 init_forward (tuple): This is the equivalent of `init` in spectral space
1079 """
1081 xp = np
1082 fft_lib = scipy.fft
1083 sparse_lib = scipy.sparse
1084 linalg = scipy.sparse.linalg
1085 dtype = mesh
1086 fft_backend = 'scipy'
1087 fft_comm_backend = 'MPI'
@classmethod
def setup_GPU(cls):
    """
    Switch the class-level backends to their GPU counterparts: CuPy arrays, `cupyx.scipy` sparse linear algebra and
    FFTs, NCCL as communication backend for distributed FFTs and `cupy_mesh` as data type.
    """
    import cupy as cp
    import cupyx.scipy.sparse as sparse_lib
    import cupyx.scipy.sparse.linalg as linalg
    import cupyx.scipy.fft as fft_lib
    from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

    cls.xp, cls.sparse_lib, cls.linalg = cp, sparse_lib, linalg
    cls.fft_lib, cls.fft_backend, cls.fft_comm_backend = fft_lib, 'cupyx-scipy', 'NCCL'
    cls.dtype = cupy_mesh
@classmethod
def setup_CPU(cls, useFFTW=False):
    """
    Switch the class-level backends to their CPU counterparts: NumPy arrays, `scipy.sparse` linear algebra, MPI as
    communication backend for distributed FFTs and `mesh` as data type.

    Args:
        useFFTW (bool): Use FFTW from mpi4py-fft for the FFTs instead of `scipy.fft`
    """
    cls.xp, cls.sparse_lib, cls.linalg = np, scipy.sparse, scipy.sparse.linalg

    if not useFFTW:
        cls.fft_backend, cls.fft_lib = 'scipy', scipy.fft
    else:
        from mpi4py_fft import fftw

        cls.fft_backend, cls.fft_lib = 'fftw', fftw

    cls.fft_comm_backend = 'MPI'
    cls.dtype = mesh
def __init__(self, comm=None, useGPU=False, debug=False):
    """
    Set up an empty spectral helper. Axes and solution components are added afterwards, before calling
    `setup_fft`.

    Args:
        comm (mpi4py.Intracomm): MPI communicator
        useGPU (bool): Whether to use GPUs
        debug (bool): Perform additional checks at extra computational cost
    """
    self.comm, self.debug, self.useGPU = comm, debug, useGPU

    # choose array / sparse / FFT backends
    (self.setup_GPU if useGPU else self.setup_CPU)()

    # containers filled later by add_axis, add_component and the BC machinery
    self.axes = []
    self.components = []
    self.full_BCs = []
    self.BC_mat = None
    self.BCs = None
    self.fft_cache = {}

    self.logger = logging.getLogger(name='Spectral Discretization')
    if debug:
        self.logger.setLevel(logging.DEBUG)
@property
def u_init(self):
    """
    Empty data container in physical space, built from `self.init`.
    """
    container_cls = self.dtype
    return container_cls(self.init)
@property
def u_init_forward(self):
    """
    Empty data container in spectral space, built from `self.init_forward`.
    """
    container_cls = self.dtype
    return container_cls(self.init_forward)
@property
def u_init_physical(self):
    """
    Empty data container in physical space, built from `self.init_physical`.
    """
    container_cls = self.dtype
    return container_cls(self.init_physical)
@property
def shape(self):
    """
    Shape of a single solution component: the shape in `self.init` without the leading component axis.
    """
    full_shape = self.init[0]
    return full_shape[1:]
@property
def ndim(self):
    """Number of spatial dimensions, i.e. of 1D bases added as axes."""
    n_axes = len(self.axes)
    return n_axes
@property
def ncomponents(self):
    """Number of solution components."""
    n_components = len(self.components)
    return n_components
@property
def V(self):
    """
    Volume of the domain: the product of the lengths `L` of all axes.
    """
    lengths = [base.L for base in self.axes]
    return np.prod(lengths)
def add_axis(self, base, *args, **kwargs):
    """
    Append a 1D basis as new axis of the domain. Further positional and keyword arguments are forwarded to the
    constructor of the 1D basis; refer to the documentation of the 1D bases for possible arguments. The new basis
    inherits the GPU setting as well as the array and sparse modules of this helper.

    Args:
        base (str): 1D spectral method (case insensitive)

    Raises:
        NotImplementedError: If the method is unknown
    """
    kwargs['useGPU'] = self.useGPU
    name = base.lower()

    if name in ('chebychov', 'chebychev', 'cheby', 'chebychovhelper'):
        helper_cls = ChebychevHelper
    elif name in ('fft', 'fourier', 'ffthelper'):
        helper_cls = FFTHelper
    elif name in ('ultraspherical', 'gegenbauer'):
        helper_cls = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    new_axis = helper_cls(*args, **kwargs)
    self.axes.append(new_axis)
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list/tuple of str): Name(s) of component(s). Lists/tuples are added element by element
            (recursively). String subclasses such as `numpy.str_` are accepted as names.

    Raises:
        Exception: If a component of this name was added before
        NotImplementedError: If `name` is neither a string nor a list/tuple
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError(f'Don\'t know how to add a component of {type(name)=}')
def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str or list/tuple of str): Name(s) of component(s)

    Returns:
        int, or tuple of int for a list/tuple of names: Index/indices of the component(s)

    Raises:
        ValueError: If a name is not a known component
        NotImplementedError: If `name` has an unsupported type
    """
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # A tuple rather than a generator: it can be iterated repeatedly and indexed, and unknown names fail
        # immediately instead of on first iteration.
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')
def get_empty_operator_matrix(self, diag=False):
    """
    Build a container for the operators connecting the solution components. Every entry is the same sparse zero
    matrix of the size of a single component.

    Args:
        diag (bool): Whether operator is block-diagonal. If so, one entry per component is returned instead of a
            nested components x components list.

    Returns:
        list (or list of lists) containing sparse zeros
    """
    n_comp = len(self.components)
    zero = self.get_Id() * 0
    if not diag:
        return [[zero] * n_comp for _ in range(n_comp)]
    return [zero] * n_comp
def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC only at index 0 of the other direction(s). If False, the BC is put at all
            positions in the other direction(s).

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # all-zero N x N matrix on the host in LIL format, whose row `line` is then replaced by the 1D BC row
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # copy the row to the host to fill the scipy matrix (presumably a CuPy array on GPU)
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        mat = self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # diagonal with a single 1 at index 0: the BC only acts at the first index of the other axis
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        mats = [
            None,
        ] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
    elif ndim == 3:
        mats = [
            None,
        ] * ndim

        for ax in range(ndim):
            if ax == axis:
                continue

            if scalar:
                # diagonal with a single 1 at index 0: the BC only acts at the first index of this axis
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
            else:
                _Id = self.axes[ax].get_Id()

            mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

        # `kron` takes two matrices, hence the nested call for three factors
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
    else:
        raise NotImplementedError(
            f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
        )
    mat = self.eliminate_zeros(mat)
    return mat
def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    The entries of the right hand side mask are located in the same way as in `add_BC`. The global `line` is
    translated to an index local to this rank, and the full extent is taken along the other axes.

    Args:
        component (str): Name of the component the BC acts on
        equation (str): Name of the equation for the component the BC was put in
        axis (int): Axis the BC was added to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line the BC went in
        scalar (bool): Whether the BC was added as scalar BC, i.e. only at index 0 of the other direction(s)
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    _BC = self.eliminate_zeros(_BC)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    N = self.axes[axis].N
    global_line = (N + line) % N
    local_slice = self.local_slice()[axis]

    # only the rank holding the row `line` along `axis` in spectral space updates its mask
    if global_line not in self.xp.arange(N)[local_slice]:
        return

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        # translate the global line to the index in the local part of the distributed mask
        slices[axis + 1] = global_line - local_slice.start
    self.BC_rhs_mask[(*slices,)] = False
def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Besides the matrix entry, the BC is recorded in `full_BCs`, and the corresponding entries of the right hand side
    mask are flagged.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC only at index 0 of the other direction(s) instead of at all positions there
    """
    bc_matrix = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    eq_idx = self.index(equation)
    self.BC_mat[eq_idx][self.index(component)] += bc_matrix

    record = {
        'component': component,
        'equation': equation,
        'axis': axis,
        'kind': kind,
        'v': v,
        'line': line,
        'scalar': scalar,
        **kwargs,
    }
    self.full_BCs.append(record)

    if scalar:
        target = [eq_idx] + [0] * self.ndim
        target[axis + 1] = line
        # with a communicator, only rank 0 flags the scalar BC
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[tuple(target)] = True
        return

    target = [eq_idx, *self.global_slice(True)]
    N = self.axes[axis].N
    global_line = (N + line) % N
    if global_line in self.get_indices(True)[axis]:
        target[axis + 1] = global_line - self.local_slice()[axis].start
        self.BC_rhs_mask[tuple(target)] = True
def setup_BCs(self):
    """
    Convert the list of lists of BCs to the boundary condition operator.
    Boundary bordering also requires zeroing all other entries in rows that contain a boundary condition, so a
    diagonal matrix with zeros in exactly those rows is set up as well. Finally, the BC values are put in an empty
    right hand side and transformed to spectral space, ready to be added when solving. `BC_rhs_mask` is deleted
    afterwards.
    """
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    n_dofs = np.prod(self.init[0])
    self.BC_zero_index = self.xp.arange(n_dofs)[self.BC_rhs_mask.flatten()]

    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = self.sparse_lib.diags(keep_rows).tocsc()

    # BC values in spectral space, to be added to the right hand side
    rhs_with_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(rhs_with_BCs).view(self.xp.ndarray)
    del self.BC_rhs_mask
def check_BCs(self, u):
    """
    Check that the solution satisfies the non-scalar boundary conditions, by transforming along the axis of each BC
    and applying the BC row of the 1D basis. Only up to two dimensions are supported.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a BC is violated
    """
    assert self.ndim < 3
    ignored = ('component', 'equation', 'axis', 'v', 'line', 'scalar')

    for axis in range(self.ndim):
        relevant = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not relevant:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for BC in relevant:
            bc_kwargs = {key: val for key, val in BC.items() if key not in ignored}

            if axis == 0:
                get = self.axes[axis].get_BC(**bc_kwargs) @ u_hat[self.index(BC['component'])]
            elif axis == 1:
                get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**bc_kwargs)
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'
def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions in a matrix by replacing rows with BCs. The rows that carry BCs are zeroed by
    `BC_line_zero_matrix`, then the BC operator `BCs` is added.

    Args:
        A: matrix to put the BCs in

    Returns:
        matrix with BCs
    """
    zeroed = self.BC_line_zero_matrix @ A
    return zeroed + self.BCs
def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use. Entries of `rhs_hat` that are
    replaced by BCs are set to zero in place before the BC values are added.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate (once) a mask of the entries in spectral space that are replaced by boundary conditions, such that
        # they can be zeroed before adding the BC values.
        self._rhs_hat_zero_mask = self.newDistArray(forward_output=True).astype(bool).view(self.xp.ndarray)

        for axis in range(self.ndim):
            for bc in (me for me in self.full_BCs if axis == me['axis']):
                target = [self.index(bc['equation']), *self.global_slice(True)]
                N = self.axes[axis].N
                global_line = (N + bc['line']) % N
                if global_line in self.get_indices(True)[axis]:
                    target[axis + 1] = global_line - self.local_slice()[axis].start
                    self._rhs_hat_zero_mask[tuple(target)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat
def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    For each axis, the right hand side is transformed along that axis only, the values of all BCs in that axis are
    inserted and the result is transformed back. Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms
    needed.

    Args:
        rhs: Right hand side in physical space (not flattened)

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim
    for axis in range(ndim):
        along = (axis - ndim,)
        rhs_hat_1D = self.transform(rhs, axes=along)

        for bc in [me for me in self.full_BCs if axis == me['axis']]:
            target = [self.index(bc['equation']), *self.global_slice(True)]
            N = self.axes[axis].N
            global_line = (N + bc['line']) % N
            if global_line in self.get_indices(True)[axis]:
                target[axis + 1] = global_line - self.local_slice()[axis].start
                rhs_hat_1D[tuple(target)] = bc['v']

        rhs = self.itransform(rhs_hat_1D, axes=along)

    return rhs
def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python

            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled, modified in place
        equation (str): The equation of the component you want this in
        relations: (dict): Relations between quantities, mapping component names to operators
    """
    for quantity in relations:
        A[self.index(equation)][self.index(quantity)] = relations[quantity]
def eliminate_zeros(self, A):
    """
    Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat.
    Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`.
    Therefore, on GPU the matrix is copied to the host, pruned there and copied back.

    Args:
        A: sparse matrix to be pruned

    Returns:
        CSC sparse matrix
    """
    on_gpu = self.useGPU
    host = A.get() if on_gpu else A
    host = host.tocsc()
    host.eliminate_zeros()
    return self.sparse_lib.csc_matrix(host) if on_gpu else host
def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator to be converted, one block per pair of components

    Returns:
        sparse linear operator
    """
    single_component = len(self.components) == 1
    op = M[0][0] if single_component else self.sparse_lib.bmat(M, format='csc')
    return self.eliminate_zeros(op)
def get_wavenumbers(self):
    """
    Get grid in spectral space: the locally held wavenumbers of every axis, combined by `meshgrid` with matrix
    ('ij') indexing.
    """
    local = self.local_slice(True)
    grids = [base.get_wavenumbers()[local[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')
def get_grid(self, forward_output=False):
    """
    Get grid in physical space: the locally held 1D grid points of every axis, combined by `meshgrid` with matrix
    ('ij') indexing.

    Args:
        forward_output (bool): Whether to use the local slices of spectral (True) or physical (False) space
    """
    local = self.local_slice(forward_output)
    grids = [base.get_1dgrid()[local[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')
def get_indices(self, forward_output=True):
    """
    Get the global indices of the entries held locally along each axis.

    Args:
        forward_output (bool): Whether to use the local slices of spectral (True) or physical (False) space

    Returns:
        list of self.xp.ndarray: One array of indices per axis
    """
    local = self.local_slice(forward_output)
    return [self.xp.arange(base.N)[local[i]] for i, base in enumerate(self.axes)]
@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Get a parallel transform object from mpi4py-fft. Results are cached per combination of arguments by the `cache`
    decorator. NOTE: that decorator builds its key from the arguments, so they need to be hashable (e.g. pass
    `padding` as a tuple rather than a list).

    Args:
        axes (tuple): Axes to apply the 1D transforms of the bases along. All axes if None. The remaining axes get an
            identity "transform".
        padding (list or tuple): Padding factor per axis. No padding (1.0 everywhere) if None.
        grid: Forwarded to `mpi4py_fft.PFFT`

    Returns:
        `mpi4py_fft.PFFT` object, or None in 1D or without communicator
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT, newDistArray

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    def no_transform(u, *args, **kwargs):
        return u

    # identity "transforms" everywhere, replaced by the 1D transforms of the bases in the requested axes
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )

    # do a transform to do the planning
    _u = newDistArray(pfft, forward_output=False)
    pfft.backward(pfft.forward(_u))
    return pfft
def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get (and cache in `fft_cache`) an FFT. When using MPI, we use `PFFT` objects generated by mpi4py-fft. Without a
    communicator, `fftn`/`ifftn` of the array module are returned and no transform object is available.

    Args:
        axes (tuple): Axes you want to transform over. Defaults to all axes, counted from the back.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing. Defaults to no padding.
        shape (tuple): Shape of the transform. Defaults to the global shape of a single component.

    Returns:
        transform

    Raises:
        AssertionError: If padding is requested without communicator
        KeyError: If `direction` is not one of the options above
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))
    if shape is None:
        shape = self.global_shape[1:]
    if padding is None:
        padding = [1] * self.ndim

    key = (axes, direction, tuple(padding), tuple(shape))
    if key in self.fft_cache:
        return self.fft_cache[key]

    if self.comm is None:
        assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'
        if direction == 'forward':
            self.fft_cache[key] = self.xp.fft.fftn
        elif direction == 'backward':
            self.fft_cache[key] = self.xp.fft.ifftn
        elif direction == 'object':
            self.fft_cache[key] = None
    elif direction == 'object':
        from mpi4py_fft import PFFT

        self.fft_cache[key] = PFFT(
            comm=self.comm,
            shape=shape,
            axes=sorted(axes),
            dtype='D',
            collapse=False,
            backend=self.fft_backend,
            comm_backend=self.fft_comm_backend,
            padding=padding,
        )
    else:
        transform_obj = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)
        if direction == 'forward':
            self.fft_cache[key] = transform_obj.forward
        elif direction == 'backward':
            self.fft_cache[key] = transform_obj.backward

    # an unknown `direction` leaves no entry for this key
    return self.fft_cache[key]
def local_slice(self, forward_output=True):
    """
    Get the slices of the global data held on this rank, one per axis. Without a parallel transform object, the
    full extent of every axis is returned.

    Args:
        forward_output (bool): Whether for spectral (True) or physical (False) space
    """
    if not self.fft_obj:
        return [slice(0, base.N) for base in self.axes]
    return self.get_pfft().local_slice(forward_output=forward_output)
def global_slice(self, forward_output=True):
    """
    Get slices over the full global extent of every axis. Without a parallel transform object, this is the same as
    `local_slice`.

    Args:
        forward_output (bool): Whether for spectral (True) or physical (False) space
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    global_shape = self.fft_obj.global_shape(forward_output=forward_output)
    return [slice(0, extent) for extent in global_shape]
def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    If no component was added, a single component 'u' is added. The local shapes in `init`, `init_physical` and
    `init_forward` are computed directly from the local slices. No global-sized scratch array is allocated for this.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_pfft(axes=axes)

    def local_shape(forward_output):
        """Local data shape: all components times the length of the local slice along each axis"""
        slices = self.local_slice(forward_output)
        return (self.global_shape[0],) + tuple(
            len(range(*s.indices(n))) for s, n in zip(slices, self.global_shape[1:])
        )

    self.init = (local_shape(False), self.comm, np.dtype('float'))
    self.init_physical = (local_shape(False), self.comm, np.dtype('float'))
    self.init_forward = (
        local_shape(True),
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)
def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Without communicator, an array of zeros with shape and dtype from `self.init` is returned regardless of the
    other arguments. Without a parallel transform object, an empty data container in spectral or physical space is
    returned.

    Args:
        pfft: Parallel transform object from mpi4py-fft. Taken from `get_pfft` if not given.
        forward_output (bool): Whether the array is for spectral (True) or physical (False) space
        val: Initial value of the entries
        rank (int): Number of leading component axes
        view (bool): Return the underlying array instead of the distributed array object

    Returns:
        distributed array
    """
    if self.comm is None:
        return self.xp.zeros(self.init[0], dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft or self.get_pfft()
    if pfft is None:
        return self.u_init_forward if forward_output else self.u_init

    shape_without_components = pfft.global_shape(forward_output)
    pencil = pfft.pencil[forward_output]
    transform_arrays = pfft.forward.output_array if forward_output is True else pfft.forward.input_array
    dtype = transform_arrays.dtype
    global_shape = (self.ncomponents,) * rank + shape_without_components

    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    arr = darraycls(global_shape, subcomm=pencil.subcomm, val=val, dtype=dtype, alignment=pencil.axis, rank=rank)
    return arr.v if view else arr
def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Find the axes along which the locally held array `u` has the full global extent of a distributed array of a
    candidate parallel transform object.

    Args:
        u: Locally held array including the leading component axis
        forward_output (bool): Whether to compare with arrays for spectral (True) or physical (False) space
        padding (tuple): Padding of the transform. If given, combinations of padding some or all axes are tried in
            turn until an aligned axis is found (only 2D and 3D).
        kwargs: Forwarded to `get_pfft`

    Returns:
        list: Aligned axes, counted without the component axis. Always [0] without communicator.

    Raises:
        AssertionError: If no aligned axis is found
        NotImplementedError: If padding is given in other than two or three dimensions
    """
    if self.comm is None:
        return [0]

    def _alignment(pfft):
        # axes in which the local extent of `u` equals the global extent of a distributed array of `pfft`
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        # candidates, tried in order: padding in single axes, (3D only) in pairs of axes, in all axes, then none
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes
def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Redistribute the locally held array `u` such that the result is aligned in `axis`.

    The alignment of `u` is inferred first. Then a fresh distributed array is redistributed to each candidate
    alignment until its local shape matches `u`. It is filled with `u` and finally redistributed to `axis`.

    Args:
        u: Locally held array including the leading component axis
        axis (int): Axis the result should be aligned in
        forward_output (bool): Whether the result is for spectral (True) or physical (False) space
        kwargs: Forwarded to `get_pfft` and `infer_alignment`

    Returns:
        Distributed array aligned in `axis`, or `u` itself without communicator

    Raises:
        Exception: If no candidate alignment matches the local shape of `u`
    """
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # NOTE(review): the alignment of `u` is always inferred with forward_output=False, independently of the
    # `forward_output` argument -- presumably intended, confirm
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )
def transform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from physical to spectral space.

    Without a parallel transform object, the 1D transforms of the bases are applied one after the other to a copy
    of `u`, skipping the leading component axis. Otherwise, mpi4py-fft transforms every component. With padding,
    the result is divided by the product of the padding factors.

    Args:
        u: Data in physical space including the leading component axis
        axes (tuple): Axes to transform along. All axes if not given.
        padding (tuple): Padding factor per axis

    Returns:
        Data in spectral space
    """
    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if pfft is None:
        u_hat = u.copy()
        for i in axes if axes else range(self.ndim):
            u_hat = self.axes[i].transform(u_hat, axes=(i + 1 if i >= 0 else i,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=False, rank=1)
    _out = self.newDistArray(pfft, forward_output=True, rank=1)

    if _in.shape != u.shape:
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs)
    else:
        _in[...] = u

    for component in range(self.ncomponents):
        pfft.forward(_in[component], _out[component], normalize=False)

    if padding is not None:
        _out /= np.prod(padding)
    return _out
def itransform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from spectral to physical space.

    Without a parallel transform object, the 1D inverse transforms of the bases are applied one after the other to
    a copy of `u`, skipping the leading component axis. Otherwise, mpi4py-fft transforms every component back. With
    padding, the result is multiplied by the product of the padding factors.

    Args:
        u: Data in spectral space including the leading component axis
        axes (tuple): Axes to transform along. All axes if not given.
        padding (tuple): Padding factor per axis. The padded resolution must be integer in every axis.

    Returns:
        Data in physical space
    """
    if padding is not None:
        assert all(
            (base.N * padding[i]) % 1 == 0 for i, base in enumerate(self.axes)
        ), 'Cannot do this padding with this resolution. Resulting resolution must be integer'

    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
    if pfft is None:
        u_phys = u.copy()
        for i in axes if axes else range(self.ndim):
            u_phys = self.axes[i].itransform(u_phys, axes=(i + 1 if i >= 0 else i,))
        return u_phys

    _in = self.newDistArray(pfft, forward_output=True, rank=1)
    _out = self.newDistArray(pfft, forward_output=False, rank=1)

    if _in.shape != u.shape:
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs)
    else:
        _in[...] = u

    for component in range(self.ncomponents):
        pfft.backward(_in[component], _out[component], normalize=True)

    if padding is not None:
        _out *= np.prod(padding)
    return _out
def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
    which you can sort out via the `SpectralHelper.local_slice()` attribute. This restricts both rows and columns of
    the matrix to the spectral-space slice held on this rank along `axis`.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions.

    Returns:
        sparse local matrix
    """
    local = self.local_slice(True)[axis]
    return M.tocsc()[local, local]
def expand_matrix_ND(self, matrix, aligned):
    """
    Expand a 1D matrix acting along axis `aligned` to the full (local) domain. The Kronecker product is taken with
    identity matrices in all other axes, and every factor is restricted to the modes held on this rank.

    Args:
        matrix (sparse matrix): 1D matrix
        aligned (int): Axis the matrix acts along

    Returns:
        CSC sparse matrix

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = self.sparse_lib
    others = np.delete(np.arange(self.ndim), aligned)
    ndim = len(others) + 1

    if ndim == 1:
        return self.eliminate_zeros(matrix)

    if ndim not in (2, 3):
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    factors = [None] * ndim
    factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
    for other in others:
        factors[other] = self.get_local_slice_of_1D_matrix(sp.eye(self.axes[other].N), other)

    if ndim == 2:
        expanded = sp.kron(*factors)
    else:
        expanded = sp.kron(factors[0], sp.kron(*factors[1:]))
    return self.eliminate_zeros(expanded)
def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible.

    In more than one dimension, the 1D filter is combined with the identity matrices (`get_Id`) of the other axes
    via Kronecker products. As for the other operators, every 1D factor is first restricted to the modes held on
    this rank.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    mats = [self.get_local_slice_of_1D_matrix(base.get_Id(), axis=i) for i, base in enumerate(self.axes)]
    mats[axis] = self.get_local_slice_of_1D_matrix(self.axes[axis].get_filter_matrix(**kwargs), axis=axis)

    # `kron` multiplies exactly two matrices (a third positional argument would be taken as the output format), so
    # fold the factors pairwise
    F = mats[0]
    for factor in mats[1:]:
        F = self.sparse_lib.kron(F, factor)
    return F
def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get differentiation matrix along specified axes. `kwargs` are forwarded to the 1D base implementation. The
    expanded 1D matrices are multiplied in the order of `axes`.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix
    """
    first = axes[0]
    D = self.expand_matrix_ND(self.axes[first].get_differentiation_matrix(**kwargs), first)
    for ax in axes[1:]:
        D_1D = self.axes[ax].get_differentiation_matrix(**kwargs)
        D = D @ self.expand_matrix_ND(D_1D, ax)

    self.logger.debug(f'Set up differentiation matrix along axes {axes} with kwargs {kwargs}')
    return D
def get_integration_matrix(self, axes):
    """
    Get integration matrix to integrate along specified axes. The expanded 1D matrices are multiplied in the order
    of `axes`.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """
    first = axes[0]
    S = self.expand_matrix_ND(self.axes[first].get_integration_matrix(), first)
    for ax in axes[1:]:
        S_1D = self.axes[ax].get_integration_matrix()
        S = S @ self.expand_matrix_ND(S_1D, ax)
    return S
def get_Id(self):
    """
    Get identity matrix: the product of the identities of all 1D bases, each expanded to the full domain.

    Returns:
        sparse identity matrix
    """
    identity = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
    for ax, base in enumerate(self.axes[1:], start=1):
        identity = identity @ self.expand_matrix_ND(base.get_Id(), ax)
    return identity
def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get Dirichlet recombination matrix along axis. Note that it only makes sense in directions discretized with
    variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    recombination_1D = self.axes[axis].get_Dirichlet_recombination_matrix()
    expanded = self.expand_matrix_ND(recombination_1D, axis)
    return expanded
def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases do a change between bases while differentiating. This method returns matrices that change
    the basis to whatever you want. The expanded 1D matrices are multiplied in the order of `axes`.
    Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes, counted from the back.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))

    first = axes[0]
    C = self.expand_matrix_ND(self.axes[first].get_basis_change_matrix(**kwargs), first)
    for ax in axes[1:]:
        C_1D = self.axes[ax].get_basis_change_matrix(**kwargs)
        C = C @ self.expand_matrix_ND(C_1D, ax)

    self.logger.debug(f'Set up basis change matrix along axes {axes} with kwargs {kwargs}')
    return C