Coverage for pySDC/helpers/spectral_helper.py: 90%

765 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 11:36 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5from functools import partial, wraps 

6import logging 

7 

8 

def cache(func):
    """
    Decorator for caching return values of methods on a per-instance basis.
    This is very similar to `functools.cache`, but without the memory leaks (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/): the cache dictionary is stored as an attribute of the
    instance itself rather than in a global cache that keeps every instance alive.

    The decorated callable must be a method: its first positional argument is used as the owner of the cache.
    All further positional and keyword arguments make up the cache key and must therefore be hashable.

    Example:

    .. code-block:: python

        class Counter:
            def __init__(self):
                self.num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        c = Counter()
        c.increment(0)  # returns 1, c.num_calls = 1
        c.increment(1)  # returns 2, c.num_calls = 2
        c.increment(0)  # returns 1, c.num_calls = 2

    Args:
        func (function): The method you want to cache the return value of

    Returns:
        function: Wrapped method that returns the stored result for repeated arguments
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # lazily create the per-instance cache dictionary on first use
        if not hasattr(self, attr_cache):
            setattr(self, attr_cache, {})

        store = getattr(self, attr_cache)

        # keyword arguments are order-insensitive, hence the frozenset
        key = (args, frozenset(kwargs.items()))
        if key in store:
            return store[key]
        result = func(self, *args, **kwargs)
        store[key] = result
        return result

    return wrapper

54 

55 

class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs

    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both `useGPU` and `useFFTW` are requested
        """
        # Validate the incompatible options before any module setup: `setup_GPU` changes class-wide attributes and
        # imports CuPy, which would otherwise either mask this error or leave the class in a modified state.
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """
        Switch to GPU modules. The modules are stored as class attributes, so this affects all instances of `cls`.
        """
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

        # keep the remaining attributes consistent with `setup_CPU`, which sets them as well
        cls.fft_backend = 'cupyx-scipy'
        cls.fft_comm_backend = 'NCCL'
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """
        Switch to CPU modules. The modules are stored as class attributes, so this affects all instances of `cls`.

        Args:
            useFFTW (bool): Use FFTW from `mpi4py-fft` instead of `scipy.fft` for the transforms
        """

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        # Note that every entry references the same object `O`; entries are meant to be replaced, not modified in place.
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
            individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter that keeps the modes with `kmin <= |k| < kmax` and removes all others.
        Note that with the default `kmax`, the mode(s) with the largest `|k|` are removed as well.

        Args:
            kmin (int): Lower limit of the bandpass filter (inclusive)
            kmax (int): Upper limit of the bandpass filter (exclusive), defaults to the largest `|k|`

        Returns:
            sparse matrix
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        if self.useGPU:
            # The filter is assembled with SciPy on the host, hence the mask has to be moved to the host as well.
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

250 

251 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary. The grid as well as the differentiation and integration matrices
                are mapped to [x0, x1] by an affine transformation, whereas the boundary condition rows are expressed
                in the reference coordinate on [-1, 1].
            x1 (float): Coordinate of right boundary. See `x0`.
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its reverse is implemented
        '''
        N = N if N else self.N
        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                # conversion to the same basis; use the requested size rather than `self.N`
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # only forward conversions are implemented explicitly, backward ones are obtained by inversion
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError from E

            import scipy.sparse as scipy_sparse

            if self.sparse_lib == scipy_sparse:
                mat = self.sparse_lib.linalg.inv(fwd.tocsc())
            else:
                # CuPy does not provide a sparse inverse at the time of writing, so we invert on the host
                mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration in the reference coordinate on [-1, 1], i.e. 0 corresponds to the
                centre of the domain. Only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: If `lbnd` is not 0
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        # integrating U_n gives T_{n+1} / (n+1), so first change basis T -> U and then shift down with the factors
        S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = self.xp.arange(self.N)
        S = S.tocsc()

        # the first row sets the integration constant such that the integral vanishes at the lower bound
        S[0, 1::2] = (
            (n / (2 * (n + 1)))[1::2]
            * (-1) ** (self.xp.arange(self.N // 2))
            / (np.append([1], self.xp.arange(self.N // 2 - 1) + 1))
        )

        # Chain rule for the affine map from the reference interval [-1, 1] to [x0, x1]: dx = lin_trf_fac * dy.
        # This applies to the entire matrix, not only to the row with the integration constant.
        return S * self.lin_trf_fac

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        xp = self.xp
        j = xp.arange(self.N)
        k = j[:, None]

        # entry (k, j) is 2j for k < j with j - k odd and zero otherwise; vectorized to avoid element-wise assignment
        D = xp.where((k < j) & ((j - k) % 2 == 1), 2.0 * j, 0.0)

        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Shape of the output, passed on to the FFT library as `s`

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            # apply the normalization along this axis; padded modes use the normalization of the highest mode
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Shape of the output, passed on to the FFT library as `s`

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # Work on a copy and undo the normalization of every axis cumulatively. `transform` multiplies the normalization
        # of all axes into the result, so restarting from `u` on every axis would only undo the last one.
        _u = u.copy()
        for axis in axes:

            if self.N != _u.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                N = self.N
                src = _u
                _u = self.xp.zeros_like(src)

                # copy first half
                su = [slice(None)] * src.ndim
                su[axis] = slice(0, N // 2 + 1)
                _u[tuple(su)] = src[tuple(su)]

                # copy second half
                su = [slice(None)] * src.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * src.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                _u[tuple(s_u)] = src[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * src.ndim
                    su[axis] = N // 2
                    _u[tuple(su)] *= 2

            # generate norm
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N

            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire reference interval [-1, 1].

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        # entry m >= 2 is the integral of T_m, i.e. ((-1)^m + 1) / (1 - m^2); the integral of T_1 vanishes
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition in the reference coordinate, i.e. -1, 0 or 1

        Returns:
            self.xp.ndarray: Row to put into a matrix

        Raises:
            NotImplementedError: For any other position
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_n(0) = cos(n pi / 2) = 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)

591 

592 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min(p_in, p_out), max(p_in, p_out)):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        if p_out == p_in:
            # The conversion is the identity: skip the (on GPU: host round trip and) sparse inversion.
            return mat_fwd.tocsc()

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        import scipy.sparse as sp

        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = sp.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        The constant is chosen such that the integral vanishes at -1, where T_n(-1) = (-1)^n, i.e. it is
        sum_{n>=1} (-1)^(n-1) u_hat_n along `axis`.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        n_modes = u_hat.shape[axis]

        # keep all other axes intact and drop the zeroth mode along the integration axis
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n_modes)

        # alternating signs, oriented along the integration axis so they broadcast correctly for any dimension
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = n_modes - 1
        signs = ((-1) ** (self.xp.arange(n_modes - 1))).reshape(sign_shape)

        return self.xp.sum(u_hat[tuple(slices)] * signs, axis=axis)

722 

723 

class FFTHelper(SpectralHelper1D):
    """
    1D Fourier basis for periodic problems. Data is transformed with FFTs and the differentiation and integration
    matrices are diagonal in spectral space. The default domain is [0, 2 pi).
    """

    distributable = True

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        Get equally spaced points including the left boundary and excluding the right one, which coincides with the
        left boundary due to periodicity.

        Returns:
            self.xp.array: Grid
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Get the wavenumbers in the ordering of `fftfreq`, i.e. the non-negative wavenumbers come first, followed by
        the negative ones. Be careful that this ordering is very unintuitive.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()

        if self.useGPU:
            if p > 1:
                # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
                from scipy.sparse.linalg import matrix_power

                D = self.sparse_lib.diags(1j * k).get()
                return self.sparse_lib.csc_matrix(matrix_power(D, p))
            else:
                return self.sparse_lib.diags(1j * k)
        else:
            return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.
        The zero mode, which cannot be integrated by dividing by its wavenumber, is treated separately by replacing its
        wavenumber with `1j * L`.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L
        return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p)

    def get_plan(self, u, forward, *args, **kwargs):
        """
        Get a callable that performs the (inverse) FFT.
        With the FFTW backend from `mpi4py-fft`, a plan is generated for every combination of direction, shape and
        arguments and stored in `self.plans` for reuse. Otherwise, the respective function of the FFT library is
        returned with the normalization fixed.

        Args:
            u: Data the transform is planned for
            forward (bool): Whether to get a forward or a backward transform

        Returns:
            callable that performs the transform
        """
        if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
            if 'axes' in kwargs.keys():
                kwargs['axes'] = tuple(kwargs['axes'])
            key = (forward, u.shape, args, *(me for me in kwargs.values()))
            if key in self.plans.keys():
                return self.plans[key]
            else:
                self.logger.debug(f'Generating FFT plan for {key=}')
                transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs)
                self.plans[key] = transform

            return self.plans[key]
        else:
            if forward:
                return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward'))
            else:
                return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward'))

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        FFT along axes. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform over
            shape (tuple): Shape of the output, passed on to the FFT library as `s`

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally, as passing it by keyword after `*args` clashes with any positional argument
        plan = self.get_plan(u, True, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs)

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse FFT.

        Args:
            u: Data you want to transform
            axes (tuple): Axes over which to transform
            shape (tuple): Shape of the output, passed on to the FFT library as `s`

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally, as passing it by keyword after `*args` clashes with any positional argument
        plan = self.get_plan(u, False, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        Nyquist_mode = k.min()
        return self.xp.where(k == Nyquist_mode)[0][0]

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me

886 

887 

888class SpectralHelper: 

889 """ 

890 This class has three functions: 

891 - Easily assemble matrices containing multiple equations 

892 - Direct product of 1D bases to solve problems in more dimensions 

893 - Distribute the FFTs to facilitate concurrency. 

894 

895 Attributes: 

896 comm (mpi4py.Intracomm): MPI communicator 

897 debug (bool): Perform additional checks at extra computational cost 

898 useGPU (bool): Whether to use GPUs 

899 axes (list): List of 1D bases 

900 components (list): List of strings of the names of components in the equations 

901 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

902 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

903 BCs (sparse matrix): Matrix containing only the BCs 

904 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

905 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

906 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs 

907 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

908 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

909 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

910 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

911 init (tuple): This is the same `init` that is used throughout the problem classes 

912 init_forward (tuple): This is the equivalent of `init` in spectral space 

913 """ 

914 

915 xp = np 

916 fft_lib = scipy.fft 

917 sparse_lib = scipy.sparse 

918 linalg = scipy.sparse.linalg 

919 dtype = mesh 

920 fft_backend = 'scipy' 

921 fft_comm_backend = 'MPI' 

922 

@classmethod
def setup_GPU(cls):
    """
    Switch the class-wide module handles to their CuPy counterparts.

    Afterwards arrays, sparse matrices, sparse linear algebra and FFTs are taken from CuPy / cupyx, the backend
    names forwarded to mpi4py-fft become 'cupyx-scipy' (transforms) and 'NCCL' (communication), and data is stored
    in `cupy_mesh` containers. Since the attributes are set on the class, every instance is affected.
    """
    import cupy as cp
    import cupyx.scipy.sparse as cupy_sparse
    import cupyx.scipy.sparse.linalg as cupy_linalg
    import cupyx.scipy.fft as cupy_fft
    from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

    # array and (sparse) linear algebra modules
    cls.xp, cls.sparse_lib, cls.linalg = cp, cupy_sparse, cupy_linalg

    # FFT library and the backend names used when building mpi4py-fft objects
    cls.fft_lib, cls.fft_backend, cls.fft_comm_backend = cupy_fft, 'cupyx-scipy', 'NCCL'

    cls.dtype = cupy_mesh

941 

@classmethod
def setup_CPU(cls, useFFTW=False):
    """
    Switch the class-wide module handles to numpy / scipy.

    Args:
        useFFTW (bool): Use the FFTW wrappers from mpi4py-fft (backend name 'fftw') instead of `scipy.fft`
    """
    cls.xp, cls.sparse_lib, cls.linalg = np, scipy.sparse, scipy.sparse.linalg

    if useFFTW:
        from mpi4py_fft import fftw as fft_module

        backend_name = 'fftw'
    else:
        fft_module, backend_name = scipy.fft, 'scipy'

    cls.fft_backend = backend_name
    cls.fft_lib = fft_module

    cls.fft_comm_backend = 'MPI'
    cls.dtype = mesh

961 

def __init__(self, comm=None, useGPU=False, debug=False):
    """
    Set up an empty spectral discretization. Axes and components are added afterwards via `add_axis` and
    `add_component`, after which `setup_fft` has to be called.

    Args:
        comm (mpi4py.Intracomm): MPI communicator, or None
        useGPU (bool): Whether to use GPUs
        debug (bool): Perform additional checks at extra computational cost; also sets the logger to DEBUG level
    """
    self.comm, self.debug, self.useGPU = comm, debug, useGPU

    # select the CPU or GPU modules (this changes class attributes)
    if not useGPU:
        self.setup_CPU()
    else:
        self.setup_GPU()

    # 1D bases and names of the solution components
    self.axes, self.components = [], []

    # boundary conditions are registered via `add_BC` and assembled in `setup_BCs`
    self.full_BCs = []
    self.BC_mat, self.BCs = None, None

    # transform caches
    self.fft_cache, self.fft_dealias_shape_cache = {}, {}

    self.logger = logging.getLogger(name='Spectral Discretization')
    if debug:
        self.logger.setLevel(logging.DEBUG)

993 

@property
def u_init(self):
    """
    Fresh data container laid out for physical space, i.e. constructed from `self.init`.
    """
    container_cls = self.dtype
    return container_cls(self.init)

1000 

@property
def u_init_forward(self):
    """
    Fresh data container laid out for spectral space, i.e. constructed from `self.init_forward`.
    """
    container_cls = self.dtype
    return container_cls(self.init_forward)

1007 

@property
def u_init_physical(self):
    """
    Fresh data container laid out for physical space, constructed from `self.init_physical`.
    """
    container_cls = self.dtype
    return container_cls(self.init_physical)

1014 

@property
def shape(self):
    """
    Shape of a single solution component: the local data shape `self.init[0]` without the leading component axis.
    """
    full_shape = self.init[0]
    return full_shape[1:]

1021 

@property
def ndim(self):
    """Number of spatial dimensions, i.e. the number of 1D bases added via `add_axis`."""
    n_axes = len(self.axes)
    return n_axes

1025 

@property
def ncomponents(self):
    """Number of solution components registered via `add_component`."""
    n_components = len(self.components)
    return n_components

1029 

@property
def V(self):
    """
    Volume of the domain: the product of the lengths `L` of all 1D bases.
    """
    lengths = [base.L for base in self.axes]
    return np.prod(lengths)

1036 

def add_axis(self, base, *args, **kwargs):
    """
    Add an axis to the domain, discretized with the 1D base selected by `base`. Positional and keyword arguments
    are forwarded to the constructor of that base; refer to the 1D bases for admissible arguments. The `useGPU`
    keyword is always overwritten with the setting of this helper.

    Accepted names (case-insensitive):
        - 'chebychov', 'chebychev', 'cheby', 'chebychovhelper' -> `ChebychevHelper`
        - 'fft', 'fourier', 'ffthelper' -> `FFTHelper`
        - 'ultraspherical', 'gegenbauer' -> `UltrasphericalHelper`

    Args:
        base (str): 1D spectral method

    Raises:
        NotImplementedError: if `base` is none of the names above
    """
    kwargs['useGPU'] = self.useGPU

    name = base.lower()
    if name in ('chebychov', 'chebychev', 'cheby', 'chebychovhelper'):
        helper_cls = ChebychevHelper
    elif name in ('fft', 'fourier', 'ffthelper'):
        helper_cls = FFTHelper
    elif name in ('ultraspherical', 'gegenbauer'):
        helper_cls = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    new_axis = helper_cls(*args, **kwargs)
    # share the (possibly GPU) modules of this helper with the 1D base
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
    self.axes.append(new_axis)

1058 

def add_component(self, name):
    """
    Register one or several solution components.

    Args:
        name (str or list/tuple of strings): Name(s) of component(s); nested lists/tuples are flattened

    Raises:
        Exception: if a component of that name already exists
        NotImplementedError: for any other argument type
    """
    if type(name) in [str]:
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
        return

    if type(name) not in [list, tuple]:
        raise NotImplementedError

    for component in name:
        self.add_component(component)

1075 

def index(self, name):
    """
    Get the index (or indices) of solution component(s) in `self.components`.

    Args:
        name (str or list/tuple of strings): Name(s) of component(s)

    Returns:
        int for a single name, tuple of ints (in the same order) for a list or tuple of names

    Raises:
        ValueError: if a name is not a registered component
        NotImplementedError: for unsupported argument types
    """
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # Build a tuple eagerly instead of returning a single-use generator: the result can then be reused and
        # indexed, and unknown names raise here rather than whenever the generator happens to be consumed.
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

1092 

def get_empty_operator_matrix(self, diag=False):
    """
    Create a container of zero operators, one per connection between solution components, to be filled e.g. with
    `add_equation_lhs`. All entries refer to the same zero matrix object, so fill it by assigning entries rather
    than by modifying them in place.

    Args:
        diag (bool): Whether operator is block-diagonal

    Returns:
        list of sparse zeros (diag) or list of lists of sparse zeros
    """
    n_components = len(self.components)
    zero = self.get_Id() * 0

    if diag:
        return [zero] * n_components
    return [[zero] * n_components for _ in range(n_components)]

1109 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Build the boundary-bordering matrix for a BC along `axis`. The 1D BC row of that axis is placed in row `line`
    of an otherwise zero matrix, which is then expanded to the local N-dimensional space with Kronecker products.
    If an equation has several BCs, they must go in different lines; typically the last line that does not already
    hold a BC is the best choice. `kwargs` are forwarded to `get_BC` of the 1D base.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): If True, apply the BC only to the zeroth mode in the other direction(s); otherwise it is
            applied to all modes there

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: for more than three dimensions
    """
    base = self.axes[axis]
    ndim = len(self.axes)

    bc_row = base.get_BC(kind=kind, **kwargs)
    if self.useGPU:
        # the row is inserted into a scipy matrix below, so bring it to the host first
        bc_row = bc_row.get()
    BC = scipy.sparse.eye(base.N).tolil() * 0
    BC[line, :] = bc_row

    def other_axis_factor(ax):
        # Factor for a direction other than `axis`: identity for regular BCs, or a projection onto the zeroth mode
        # for scalar BCs.
        if scalar:
            selector = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
        else:
            selector = self.axes[ax].get_Id()
        return self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ selector, axis=ax)

    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)

    if ndim == 2:
        other = (axis + 1) % ndim
        factors = [None, None]
        factors[other] = other_axis_factor(other)
        factors[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        return self.sparse_lib.csc_matrix(self.sparse_lib.kron(*factors))

    if ndim == 3:
        factors = [None if ax == axis else other_axis_factor(ax) for ax in range(ndim)]
        factors[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        return self.sparse_lib.csc_matrix(self.sparse_lib.kron(factors[0], self.sparse_lib.kron(*factors[1:])))

    raise NotImplementedError(
        f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
    )

1178 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful, e.g., when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes. The BC row is subtracted from `BC_mat`, and the
    corresponding entry of the right hand side mask is cleared. The BC is not removed from `full_BCs`.
    Forward arguments for the boundary conditions using `kwargs`. Refer to the documentation of the 1D bases for
    details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Whether the BC was added as a scalar BC
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    if scalar:
        # Mirror `add_BC`: when running distributed, scalar BCs are only marked on rank 0.
        if self.comm and self.comm.rank != 0:
            return
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        self.BC_rhs_mask[(*slices,)] = False
    else:
        # Mirror `add_BC`: convert the global line to the index local to this rank, but only if this rank owns it.
        N = self.axes[axis].N
        global_line = (N + line) % N
        if global_line in self.get_indices(True)[axis]:
            slices = [self.index(equation), *self.global_slice(True)]
            slices[axis + 1] = global_line - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = False

1212 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs. Besides the matrix, the BC is
    recorded in `full_BCs` and its row is marked in `BC_rhs_mask`.
    Forward arguments for the boundary conditions using `kwargs`. Refer to the documentation of the 1D bases for
    details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Apply the BC only to the zeroth mode in the other direction(s)
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC

    # keep all information about the BC for the right hand side and for `check_BCs`
    self.full_BCs.append(
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    )

    eq = self.index(equation)
    if scalar:
        target = [eq] + [0] * self.ndim
        target[axis + 1] = line
        # when running distributed, scalar BCs are marked on rank 0 only
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*target,)] = True
    else:
        target = [eq, *self.global_slice(True)]
        N = self.axes[axis].N
        global_line = (N + line) % N
        if global_line in self.get_indices(True)[axis]:
            # translate the global line to the index local to this rank
            target[axis + 1] = global_line - self.local_slice()[axis].start
            self.BC_rhs_mask[(*target,)] = True

1259 

def setup_BCs(self):
    """
    Assemble the boundary condition operator from the list of lists in `BC_mat`; call this after all BCs have been
    added or removed. Boundary bordering also requires zeroing all other entries in rows that hold a BC, so this
    builds a diagonal matrix with zeros in exactly those rows. Finally, the BCs are prepared in spectral space for
    cheap addition to the right hand side.
    """
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    n_local = np.prod(self.init[0])
    self.BC_zero_index = self.xp.arange(n_local)[self.BC_rhs_mask.flatten()]

    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = self.sparse_lib.diags(keep_rows)

    self.rhs_BCs_hat = self.transform(self.put_BCs_in_rhs(self.u_init))

1277 

def check_BCs(self, u):
    """
    Check that the solution satisfies the (non-scalar) boundary conditions. This is limited to one or two
    dimensions.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: if a BC is violated or if there are three or more dimensions
    """
    assert self.ndim < 3
    not_forwarded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        BCs_here = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not BCs_here:
            continue

        # transform only along the axis of the BCs
        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for BC in BCs_here:
            kwargs = {key: value for key, value in BC.items() if key not in not_forwarded}

            if axis == 0:
                get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
            elif axis == 1:
                get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'

1306 

def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions in a matrix: rows holding BCs are zeroed and the BC rows are added in their place.
    Requires `setup_BCs` to have been called.
    """
    zeroed = self.BC_line_zero_matrix @ A
    return zeroed + self.BCs

1312 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving. No transforms are needed. Entries belonging
    to BC rows are zeroed in place in `rhs_hat`, and then the precomputed `rhs_BCs_hat` is added.

    Args:
        rhs_hat: Right hand side in spectral space (modified in place)

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Build a mask of the entries that are replaced by boundary conditions, once. It is cached on the instance,
        # so BCs added after the first call are not reflected in it.
        self._rhs_hat_zero_mask = mask = self.newDistArray().astype(bool)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if bc['axis'] != axis:
                    continue
                target = [self.index(bc['equation']), *self.global_slice(True)]
                N = self.axes[axis].N
                global_line = (N + bc['line']) % N
                if global_line in self.get_indices(True)[axis]:
                    target[axis + 1] = global_line - self.local_slice()[axis].start
                    mask[(*target,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1343 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving. For each axis, the rhs is transformed along that axis only, the
    values of all BCs on that axis are written into their lines, and the result is transformed back.
    Consider `put_BCs_in_rhs_hat` to add BCs without extra transforms.

    Args:
        rhs: Right hand side in physical space, with component axis (not flattened)

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    n_dim = self.ndim
    for axis in range(n_dim):
        rhs_hat_1D = self.transform(rhs, axes=(axis - n_dim,))

        for bc in self.full_BCs:
            if bc['axis'] != axis:
                continue
            target = [self.index(bc['equation']), *self.global_slice(True)]
            N = self.axes[axis].N
            global_line = (N + bc['line']) % N
            if global_line in self.get_indices(True)[axis]:
                target[axis + 1] = global_line - self.local_slice()[axis].start
                rhs_hat_1D[(*target,)] = bc['v']

        rhs = self.itransform(rhs_hat_1D, axes=(axis - n_dim,))

    return rhs

1377 

def add_equation_lhs(self, A, equation, relations, diag=False):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse
    matrices that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled in (modified in place)
        equation (str): The equation of the component you want this in
        relations: (dict): Maps component names to the operators acting on them in this equation
        diag (bool): Whether operator is block-diagonal
    """
    if diag:
        for k, v in relations.items():
            assert k == equation, 'You are trying to put a non-diagonal equation into a diagonal operator'
            A[self.index(equation)] = v
    else:
        for k, v in relations.items():
            A[self.index(equation)][self.index(k)] = v

1421 

def convert_operator_matrix_to_operator(self, M, diag=False):
    """
    Promote the list (of lists) of sparse matrices to a single sparse CSC matrix that can be used as a linear
    operator. See the documentation of `SpectralHelper.add_equation_lhs` for an example. With a single component,
    the only block is returned as is.

    Args:
        M (list of lists of sparse matrices): The operator blocks, or a list of diagonal blocks if `diag`
        diag (bool): Whether operator is block-diagonal

    Returns:
        sparse linear operator
    """
    if len(self.components) != 1:
        assemble = self.sparse_lib.block_diag if diag else self.sparse_lib.block_array
        return assemble(M, format='csc')
    return M[0] if diag else M[0][0]

1442 

def get_wavenumbers(self):
    """
    Get the local grid in spectral space: the wavenumbers of each 1D base restricted to the modes of this rank,
    combined with `meshgrid` using matrix ('ij') indexing.
    """
    grids = []
    for i, base in enumerate(self.axes):
        grids.append(base.get_wavenumbers()[self.local_slice(True)[i]])
    return self.xp.meshgrid(*grids, indexing='ij')

1449 

def get_grid(self, forward_output=False):
    """
    Get the local grid in physical space: the 1D grids of all bases restricted to the part held by this rank,
    combined with `meshgrid` using matrix ('ij') indexing.

    Args:
        forward_output (bool): Use the data distribution of spectral (True) or physical (False) space
    """
    grids = []
    for i, base in enumerate(self.axes):
        grids.append(base.get_1dgrid()[self.local_slice(forward_output)[i]])
    return self.xp.meshgrid(*grids, indexing='ij')

1456 

def get_indices(self, forward_output=True):
    """
    Global mode/point indices held by this rank, one array per axis.

    Args:
        forward_output (bool): Use the data distribution of spectral (True) or physical (False) space
    """
    indices = []
    for i, base in enumerate(self.axes):
        indices.append(self.xp.arange(base.N)[self.local_slice(forward_output)[i]])
    return indices

1459 

@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Get a parallel transform object from mpi4py-fft covering all axes of the data.

    The result is memoized per argument combination by the `cache` decorator, so all arguments must be hashable
    (e.g. pass `padding` as a tuple, not a list).

    Args:
        axes (tuple): Axes whose 1D bases should actually transform. Defaults to all axes. All remaining axes get
            an identity "transform" so that the object still handles every axis.
        padding (sequence of float): Padding factor per axis for dealiasing. Defaults to 1 in every direction.
        grid: Forwarded unchanged to `mpi4py_fft.PFFT`

    Returns:
        `mpi4py_fft.PFFT` object, or None in 1D or when no communicator is set
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    # identity "transform" for axes that should not be transformed
    def no_transform(u, *args, **kwargs):
        return u

    # map each (non-negative) axis to the (forward, backward) pair of its 1D base, or to the identity
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    # NOTE(review): non-distributable axes are placed at the *front* of `_axes`; this agrees with the comment above
    # only if mpi4py-fft performs the forward transforms in reverse order of `axes` -- confirm with the mpi4py-fft docs.
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )
    return pfft

1497 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get a transform, cached in `self.fft_cache`.

    Without a communicator, `self.xp.fft.fftn` / `self.xp.fft.ifftn` are returned for "forward" / "backward", and
    None for "object". With a communicator, a `PFFT` object from mpi4py-fft over the sorted `axes` is built, and
    its `forward` / `backward` methods are returned for the respective directions.

    Args:
        axes (tuple): Axes you want to transform over. Defaults to all axes as negative indices. Must be hashable
            because it is part of the cache key.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing. Defaults to 1 in every direction. Padding other than 1 requires a
            communicator.
        shape (tuple): Shape of the transform. Defaults to the global shape without the component axis.

    Returns:
        transform

    Raises:
        AssertionError: if padding other than 1 is requested without a communicator

    NOTE(review): any other value of `direction` never gets stored in the cache, so the final lookup raises
    KeyError.
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache.keys():
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # reuse (or create and cache) the PFFT object for this configuration
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    return self.fft_cache[key]

1558 

def local_slice(self, forward_output=True):
    """
    Slices (one per axis) selecting the part of the global data held by this rank. Without a parallel transform
    object, every rank holds everything.

    Args:
        forward_output (bool): Distribution of spectral (True) or physical (False) space
    """
    if not self.fft_obj:
        return [slice(0, base.N) for base in self.axes]
    return self.get_pfft().local_slice(forward_output=forward_output)

1564 

def global_slice(self, forward_output=True):
    """
    Slices (one per axis) spanning the entire global data. Without a parallel transform object this is the same as
    `local_slice`.

    Args:
        forward_output (bool): Global shape of spectral (True) or physical (False) space
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    return [slice(0, n) for n in self.fft_obj.global_shape(forward_output=forward_output)]

1570 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs. If no component was added, a single component 'u' is
    added.

    Sets `global_shape`, `fft_obj`, the `init` tuples (local shape, communicator, dtype) for physical and spectral
    space, an empty `BC_mat` and a boolean `BC_rhs_mask`.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_pfft(axes=axes)

    def local_shape(forward_output):
        # Shape of the part of a global array of shape `self.global_shape` held by this rank. It is computed from
        # the slices directly, instead of slicing a global-size `np.empty` array on every rank, which would reserve
        # memory for the entire global problem on each rank.
        local = self.local_slice(forward_output)
        return (self.global_shape[0],) + tuple(len(range(n)[s]) for n, s in zip(self.global_shape[1:], local))

    self.init = (
        local_shape(False),
        self.comm,
        np.dtype('float'),
    )
    self.init_physical = (
        local_shape(False),
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape(True),
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)

1620 

def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Args:
        pfft: Transform object defining the distribution. Defaults to `self.get_pfft()`.
        forward_output (bool): Array in spectral (True) or physical (False) space
        val: Initial value of the entries
        rank (int): Number of leading axes of length `self.ncomponents`
        view (bool): Return `z.v` of the distributed array instead of the array object itself

    Returns:
        Distributed array with leading component axis; see the notes in the code for the serial and
        no-transform-object cases
    """
    if self.comm is None:
        # NOTE(review): this serial branch ignores `forward_output`, `val`, `rank` and `view`: it always returns zeros
        # with the shape and dtype of `self.init` (physical space) -- confirm that this is intended.
        return self.xp.zeros(self.init[0], dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft if pfft else self.get_pfft()
    if pfft is None:
        # no parallel transform object (`get_pfft` returns None in 1D): fall back to regular data containers
        if forward_output:
            return self.u_init_forward
        else:
            return self.u_init

    global_shape = pfft.global_shape(forward_output)
    p0 = pfft.pencil[forward_output]
    if forward_output is True:
        dtype = pfft.forward.output_array.dtype
    else:
        dtype = pfft.forward.input_array.dtype
    # prepend `rank` axes of length `self.ncomponents` for the solution components
    global_shape = (self.ncomponents,) * rank + global_shape

    # use the CuPy-aware array class when the FFT backend runs on the GPU
    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
    return z.v if view else z

1652 

def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Find the axes along which the local data `u` is complete, i.e. not split across ranks.

    An axis counts as aligned if the local size of `u` along it equals the global size of a reference distributed
    array. When `padding` is given, several combinations of padded and unpadded axes are tried in turn until one
    yields at least one aligned axis.

    Args:
        u: Local data with leading component axis
        forward_output (bool): Compare against spectral (True) or physical (False) space arrays
        padding (tuple): Padding factors per axis, or None. It is passed to the cached `get_pfft`, so it must be
            hashable (e.g. a tuple).
        **kwargs: Forwarded to `get_pfft`

    Returns:
        list of aligned spatial axes (counted without the component axis); [0] without a communicator

    Raises:
        NotImplementedError: if `padding` is given in a dimension other than 2 or 3
        AssertionError: if no aligned axis is found
    """
    if self.comm is None:
        return [0]

    # spatial axes i in which u has the same (local) size as the global size of a reference array
    def _alignment(pfft):
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        # candidate paddings: each axis padded or not, tried in this order
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes

1692 

def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Copy the local data `u` into a new distributed array and redistribute it such that it is aligned in `axis`.

    The current alignment of `u` is inferred first. The new array is moved to each candidate alignment in turn
    until its local shape matches that of `u`; then the data is copied in and the result is redistributed to
    `axis`.

    Args:
        u: Local data with leading component axis
        axis (int): Axis in which the result should be aligned
        forward_output (bool): Data in spectral (True) or physical (False) space
        **kwargs: Forwarded to `get_pfft` and `infer_alignment` (e.g. `padding`)

    Returns:
        Redistributed array, or `u` itself when running without a communicator

    Raises:
        Exception: if no candidate alignment reproduces the local shape of `u`
    """
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # NOTE(review): the alignment of `u` is always inferred against physical-space arrays (forward_output=False),
    # independent of `forward_output` -- confirm this is intended.
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )

1717 

def transform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from physical to spectral space along `axes` (default: all axes).

    Without a parallel transform object, the 1D transforms of the bases are applied one axis at a time to a copy
    of `u`. Otherwise each component is transformed with the mpi4py-fft object, unnormalized; when `padding` is
    given, the result is divided by the product of the padding factors.

    Args:
        u: Data in physical space with leading component axis
        axes (tuple): Axes to transform along
        padding (tuple): Padding factors for dealiasing
        *args, **kwargs: Forwarded to `get_pfft` (and to `redistribute`)

    Returns:
        Data in spectral space
    """
    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if pfft is None:
        u_hat = u.copy()
        for i in axes or tuple(range(self.ndim)):
            # data carries the component axis first, so shift non-negative axis indices by one
            data_axis = i if i < 0 else i + 1
            u_hat = self.axes[i].transform(u_hat, axes=(data_axis,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=False, rank=1)
    _out = self.newDistArray(pfft, forward_output=True, rank=1)

    if _in.shape != u.shape:
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs)
    else:
        _in[...] = u

    for component in range(self.ncomponents):
        pfft.forward(_in[component], _out[component], normalize=False)

    if padding is not None:
        _out /= np.prod(padding)
    return _out

1743 

def itransform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from spectral to physical space along `axes` (default: all axes).

    Without a parallel transform object, the 1D inverse transforms of the bases are applied one axis at a time to
    a copy of `u`. Otherwise each component is transformed back with the mpi4py-fft object, normalized; when
    `padding` is given, the result is multiplied by the product of the padding factors.

    Args:
        u: Data in spectral space with leading component axis
        axes (tuple): Axes to transform along
        padding (tuple): Padding factors for dealiasing; `N * padding` must be an integer for every axis
        *args, **kwargs: Forwarded to `get_pfft` (and to `redistribute`)

    Returns:
        Data in physical space
    """
    if padding is not None:
        padded_sizes_are_integer = all((self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim))
        assert padded_sizes_are_integer, 'Cannot do this padding with this resolution. Resulting resolution must be integer'

    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
    if pfft is None:
        u_phys = u.copy()
        for i in axes or tuple(range(self.ndim)):
            # data carries the component axis first, so shift non-negative axis indices by one
            data_axis = i if i < 0 else i + 1
            u_phys = self.axes[i].itransform(u_phys, axes=(data_axis,))
        return u_phys

    _in = self.newDistArray(pfft, forward_output=True, rank=1)
    _out = self.newDistArray(pfft, forward_output=False, rank=1)

    if _in.shape != u.shape:
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs)
    else:
        _in[...] = u

    for component in range(self.ncomponents):
        pfft.backward(_in[component], _out[component], normalize=True)

    if padding is not None:
        _out *= np.prod(padding)
    return _out

1773 

def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Restrict a global 1D matrix to the modes this rank carries. When using distributed FFTs, each rank only holds
    a subset of modes, given by `SpectralHelper.local_slice()`; this selects the matching rows and columns.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions.

    Returns:
        sparse local matrix (CSC)
    """
    local = self.local_slice(True)[axis]
    return M.tocsc()[local, local]

1788 

def expand_matrix_ND(self, matrix, aligned):
    """
    Expand a 1D matrix acting along axis `aligned` to the local N-dimensional space. This is a Kronecker product
    with (local slices of) identities in all other directions.

    Args:
        matrix (sparse matrix): Global 1D matrix
        aligned (int): Axis the matrix acts along

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: for more than three dimensions
    """
    sp = self.sparse_lib
    other_axes = np.delete(np.arange(self.ndim), aligned)
    ndim = len(other_axes) + 1

    if ndim == 1:
        return matrix
    if ndim > 3:
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    factors = [None] * ndim
    factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
    for axis in other_axes:
        factors[axis] = self.get_local_slice_of_1D_matrix(sp.eye(self.axes[axis].N), axis)

    if ndim == 2:
        return sp.kron(*factors)
    return sp.kron(factors[0], sp.kron(*factors[1:]))

1817 

def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`: the 1D filter of that axis combined with the identities of all other axes.
    See the documentation of `get_filter_matrix` in the 1D bases for admissible kwargs.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix

    NOTE(review): unlike `expand_matrix_ND`, the factors are not restricted to the modes of this rank via
    `get_local_slice_of_1D_matrix` -- confirm this is intended for distributed runs.
    """
    from functools import reduce

    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    mats = [base.get_Id() for base in self.axes]
    mats[axis] = self.axes[axis].get_filter_matrix(**kwargs)
    # `kron` takes exactly two operands (a third positional argument would be taken as `format`), so fold the
    # factors pairwise; this is needed for three or more dimensions.
    return reduce(self.sparse_lib.kron, mats)

1832 

def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get the differentiation matrix along the given axes: the product of the expanded 1D differentiation matrices
    in the order of `axes`. `kwargs` are forwarded to the 1D base implementation.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix
    """
    expanded = [self.expand_matrix_ND(self.axes[axis].get_differentiation_matrix(**kwargs), axis) for axis in axes]

    D = expanded[0]
    for factor in expanded[1:]:
        D = D @ factor
    return D

1849 

def get_integration_matrix(self, axes):
    """
    Get the integration matrix along the given axes: the product of the expanded 1D integration matrices in the
    order of `axes`.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """
    expanded = [self.expand_matrix_ND(self.axes[axis].get_integration_matrix(), axis) for axis in axes]

    S = expanded[0]
    for factor in expanded[1:]:
        S = S @ factor
    return S

1866 

def get_Id(self):
    """
    Get the identity operator on the local N-dimensional data, i.e. the product of the expanded 1D identities of
    all bases.

    Returns:
        sparse identity matrix
    """
    expanded = [self.expand_matrix_ND(self.axes[axis].get_Id(), axis) for axis in range(self.ndim)]

    identity = expanded[0]
    for factor in expanded[1:]:
        identity = identity @ factor
    return identity

1879 

def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get the Dirichlet recombination matrix along `axis`, expanded to the local N-dimensional space. Note that this
    only makes sense in directions discretized with variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    recombination_1D = self.axes[axis].get_Dirichlet_recombination_matrix()
    expanded = self.expand_matrix_ND(recombination_1D, axis)
    return expanded

1892 

def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases change basis while differentiating. This returns the product of the expanded 1D basis
    change matrices along `axes`. Refer to the methods of the same name of the 1D bases for the parameters to pass
    as `kwargs`.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes, as negative indices starting from
            the last one.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(-(i + 1) for i in range(self.ndim))

    expanded = [self.expand_matrix_ND(self.axes[axis].get_basis_change_matrix(**kwargs), axis) for axis in axes]

    C = expanded[0]
    for factor in expanded[1:]:
        C = C @ factor
    return C