Coverage for pySDC/helpers/spectral_helper.py: 89%

771 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-18 13:09 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5from functools import partial, wraps 

6import logging 

7 

8 

def cache(func):
    """
    Decorator for caching return values of methods.
    This is very similar to `functools.cache`, but without the memory leaks (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/): the cache is stored on the instance itself in an
    attribute called ``_<method name>_cache`` instead of in a global dictionary keyed on the instance.

    The decorated function must be a method, i.e. it has to take the instance as its first argument. All remaining
    positional and keyword arguments form the cache key and therefore need to be hashable. The key is built from the
    arguments exactly as they were passed, so ``f(1)`` and ``f(x=1)`` are cached separately. On a cache hit the very
    same object is returned again, so callers should not modify returned values in place.

    Example:

    .. code-block:: python

        class Incrementer:
            def __init__(self):
                self.num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        inc = Incrementer()
        inc.increment(0)  # returns 1, inc.num_calls = 1
        inc.increment(1)  # returns 2, inc.num_calls = 2
        inc.increment(0)  # returns 1, inc.num_calls = 2

    Args:
        func (function): The method you want to cache the return value of

    Returns:
        function: Wrapped method that returns the cached value for arguments it has already seen
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # the per-instance cache is created lazily on first use
        if not hasattr(self, attr_cache):
            setattr(self, attr_cache, {})

        cache = getattr(self, attr_cache)

        # keyword arguments are unordered, hence the frozenset
        key = (args, frozenset(kwargs.items()))
        if key in cache:
            return cache[key]
        result = func(self, *args, **kwargs)
        cache[key] = result
        return result

    return wrapper

54 

55 

class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both `useGPU` and `useFFTW` are requested
        """
        # Validate first: `setup_GPU` / `setup_CPU` modify class-level attributes, so rejecting the combination only
        # afterwards would leave the class reconfigured even though construction failed.
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

        # keep the bookkeeping attributes in sync with `setup_CPU`, which sets all of them as well
        cls.fft_backend = 'cupyx-scipy'
        cls.fft_comm_backend = 'NCCL'
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """
        switch to CPU modules

        Args:
            useFFTW (bool): Use `mpi4py_fft.fftw` instead of `scipy.fft` for the transforms
        """

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        """To be implemented by the individual bases."""
        raise NotImplementedError()

    def get_integration_matrix(self):
        """To be implemented by the individual bases."""
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
            individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition

        Raises:
            NotImplementedError: Always in the base class
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Modes with absolute wavenumber `|k| < kmin` or `|k| >= kmax` are set to zero, all other
        modes are kept. Note that with the default `kmax`, the mode(s) with the largest `|k|` are removed as well.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter, defaults to the largest absolute wavenumber

        Returns:
            sparse matrix (assembled with SciPy, also when running on GPUs)
        """

        k = abs(self.get_wavenumbers())

        kmax = max(k) if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        if self.useGPU:
            # The filter is assembled with SciPy on the host, so the mask must be a host array as well
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

250 

251 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            self.xp.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like.
        Conversions that are not implemented directly are obtained by inverting the reverse conversion.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional, defaults to the resolution)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its reverse is known
        '''
        N = self.N if N is None else N
        sp = self.sparse_lib

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.eye(N, k=2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.eye(N, k=2)
            elif name[0] == name[-1]:
                # conversion into the same basis; must honour the requested size `N`
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError from E

            # The sparse inverse is computed with SciPy on the host, since CuPy does not provide it
            if self.sparse_lib is scipy.sparse:
                mat = scipy.sparse.linalg.inv(fwd.tocsc())
            else:
                mat = self.sparse_lib.csc_matrix(scipy.sparse.linalg.inv(fwd.tocsc().get()))

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: For any `lbnd` other than 0
        """
        S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = self.xp.arange(self.N)
        if lbnd == 0:
            S = S.tocsc()
            # NOTE(review): only this first row is scaled by `lin_trf_fac`; confirm for domains other than [-1, 1]
            S[0, 1::2] = (
                (n / (2 * (self.xp.arange(self.N) + 1)))[1::2]
                * (-1) ** (self.xp.arange(self.N // 2))
                / (np.append([1], self.xp.arange(self.N // 2 - 1) + 1))
            ) * self.lin_trf_fac
        else:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            sparse (CSC) differentiation matrix
        '''
        D = self.xp.zeros((self.N, self.N))
        for j in range(self.N):
            for k in range(j):
                D[k, j] = 2 * j * ((j - k) % 2)

        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Passed on to the DCT as `s`

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            # normalize along this axis; modes beyond the resolution get the normalization of the last mode
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along
            shape (tuple): Passed on to the inverse DCT as `s`

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # Work on a single copy so that reordering and normalization accumulate over all axes instead of restarting from
        # the untouched input for every axis.
        _u = u.copy()
        N = self.N
        for axis in axes:

            if N != _u.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                _reordered = self.xp.zeros_like(_u)

                # copy first half
                su = [slice(None)] * _u.ndim
                su[axis] = slice(0, N // 2 + 1)
                _reordered[tuple(su)] = _u[tuple(su)]

                # copy second half
                su = [slice(None)] * _u.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * _u.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                _reordered[tuple(s_u)] = _u[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * _u.ndim
                    su[axis] = N // 2
                    _reordered[tuple(su)] *= 2

                _u = _reordered

            # undo the normalization of `transform` along this axis
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N

            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the interval [-1, 1]: 2 for T_0, 0 for T_1 and
        ((-1)^n + 1) / (1 - n^2) for n >= 2. No scaling with the domain size is applied.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x. `x` is compared directly against -1, 0 and 1, i.e. no mapping
        from the physical domain is applied.

        Args:
            x (float): Position of the boundary condition

        Returns:
            self.xp.ndarray: Row to put into a matrix

        Raises:
            NotImplementedError: For any `x` other than -1, 0 or 1
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_n(0) = cos(n pi / 2): 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib

        return sp.eye(N) - sp.eye(N, k=2)

589 

590 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from
    0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.
        The matrix maps from T polynomials to derivative base `p` and has a single non-zero off-diagonal at offset `p`.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        return 2 ** (p - 1) * factorial(p - 1) * sp.diags(xp.arange(N - p) + p, offsets=p) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # assembled with SciPy (column assignment on LIL) and converted to the active sparse library below
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        if p_in == p_out:
            # Nothing to convert: return the identity directly instead of inverting it
            return self.sparse_lib.csc_matrix(self.sparse_lib.eye(self.N))

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min(p_in, p_out), max(p_in, p_out)):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        import scipy.sparse as sp

        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = sp.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix`
        for details. The constant is chosen such that the T-series vanishes at x=-1, i.e. it is
        sum_{n>=1} (-1)^(n-1) u_hat[n] along `axis`.

        Args:
            u_hat: Solution in spectral space
            axis (int): Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        n_modes = u_hat.shape[axis]

        # keep all other axes untouched and drop the zeroth mode along `axis`
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n_modes)

        # alternating signs, broadcast along `axis` only
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = n_modes - 1
        signs = ((-1) ** self.xp.arange(n_modes - 1)).reshape(sign_shape)

        return self.xp.sum(u_hat[tuple(slices)] * signs, axis=axis)

720 

721 

class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic 1D problems. All functions in this basis satisfy periodic boundary conditions, the
    differentiation and integration matrices are diagonal and the transforms can be distributed.
    """

    distributable = True

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points starting at the left boundary and excluding the right boundary, which is
        identified with the left boundary by periodicity.

        Returns:
            self.xp.ndarray: 1D grid
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive: the wavenumbers are in FFT (`fftfreq`) order, i.e. zero and
        the positive wavenumbers first, followed by the negative ones.

        Returns:
            self.xp.ndarray: Wavenumbers
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.
        Since the matrix is diagonal, the `p`-th power is obtained by raising the diagonal entries `i k` to the power
        `p`, which works identically on CPUs and GPUs (including `p=0`, which yields the identity).

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()
        return self.sparse_lib.diags((1j * k) ** p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.
        The matrix is diagonal with entries `(1 / (i k)) ** p`, where the zero mode is replaced by `k = i L`.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L
        return self.sparse_lib.diags((1 / (1j * k)) ** p)

    def get_plan(self, u, forward, *args, **kwargs):
        """
        Get a callable that performs the (inverse) FFT.
        With FFTW, plans are generated once per combination of direction, shape and arguments and stored in
        `self.plans`. Otherwise, the (i)fftn function of the FFT library is returned with the normalization bound.

        Args:
            u: Data to be transformed
            forward (bool): Whether to get the forward or the backward transform

        Returns:
            callable transform
        """
        if self.fft_lib.__name__ == 'mpi4py_fft.fftw':
            if 'axes' in kwargs.keys():
                kwargs['axes'] = tuple(kwargs['axes'])
            key = (forward, u.shape, args, *(me for me in kwargs.values()))
            if key not in self.plans.keys():
                self.logger.debug(f'Generating FFT plan for {key=}')
                transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs)
                self.plans[key] = transform

            return self.plans[key]
        else:
            if forward:
                return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward'))
            else:
                # unnormalized inverse; the normalization is applied in `itransform`
                return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward'))

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        FFT along axes. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform over
            shape (tuple): Passed on to the FFT as `s`

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally so that extra positional `args` cannot collide with it
        plan = self.get_plan(u, True, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs)

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse FFT.

        Args:
            u: Data you want to transform
            axes (tuple): Axes over which to transform
            shape (tuple): Passed on to the inverse FFT as `s`

        Returns:
            transformed data
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        # `forward` is passed positionally so that extra positional `args` cannot collide with it
        plan = self.get_plan(u, False, *args, axes=axes, **kwargs)
        return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        Nyquist_mode = min(k)
        return self.xp.where(k == Nyquist_mode)[0][0]

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs

        Returns:
            self.xp.ndarray: Row with `L / N` in the zero mode and zeros elsewhere
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me

884 

885 

886class SpectralHelper: 

887 """ 

888 This class has three functions: 

889 - Easily assemble matrices containing multiple equations 

890 - Direct product of 1D bases to solve problems in more dimensions 

891 - Distribute the FFTs to facilitate concurrency. 

892 

893 Attributes: 

894 comm (mpi4py.Intracomm): MPI communicator 

895 debug (bool): Perform additional checks at extra computational cost 

896 useGPU (bool): Whether to use GPUs 

897 axes (list): List of 1D bases 

898 components (list): List of strings of the names of components in the equations 

899 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

900 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

901 BCs (sparse matrix): Matrix containing only the BCs 

902 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

903 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

904 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs 

905 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

906 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

907 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

908 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

909 init (tuple): This is the same `init` that is used throughout the problem classes 

910 init_forward (tuple): This is the equivalent of `init` in spectral space 

911 """ 

912 

913 xp = np 

914 fft_lib = scipy.fft 

915 sparse_lib = scipy.sparse 

916 linalg = scipy.sparse.linalg 

917 dtype = mesh 

918 fft_backend = 'scipy' 

919 fft_comm_backend = 'MPI' 

920 

@classmethod
def setup_GPU(cls):
    """
    Point the class-wide array, sparse, linear algebra and FFT modules as well as the datatype to their CuPy
    counterparts.
    """
    import cupy
    import cupyx.scipy.sparse as gpu_sparse
    import cupyx.scipy.sparse.linalg as gpu_linalg
    import cupyx.scipy.fft as gpu_fft
    from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

    cls.xp = cupy
    cls.sparse_lib, cls.linalg = gpu_sparse, gpu_linalg

    cls.fft_lib = gpu_fft
    cls.fft_backend, cls.fft_comm_backend = 'cupyx-scipy', 'NCCL'

    cls.dtype = cupy_mesh

939 

@classmethod
def setup_CPU(cls, useFFTW=False):
    """
    Point the class-wide modules back to the NumPy/SciPy CPU implementations.

    Args:
        useFFTW (bool): Use the FFTW wrappers shipped with `mpi4py_fft` instead of `scipy.fft`
    """
    cls.xp = np
    cls.sparse_lib = scipy.sparse
    cls.linalg = scipy.sparse.linalg

    if not useFFTW:
        cls.fft_backend, cls.fft_lib = 'scipy', scipy.fft
    else:
        from mpi4py_fft import fftw

        cls.fft_backend, cls.fft_lib = 'fftw', fftw

    cls.fft_comm_backend = 'MPI'
    cls.dtype = mesh

959 

def __init__(self, comm=None, useGPU=False, debug=False):
    """
    Set up an empty N-dimensional spectral discretization. Axes and components are added afterwards.

    Args:
        comm (mpi4py.Intracomm): MPI communicator
        useGPU (bool): Whether to use GPUs
        debug (bool): Perform additional checks at extra computational cost
    """
    self.comm, self.debug, self.useGPU = comm, debug, useGPU

    # pick the array / sparse / FFT modules first
    switch_modules = self.setup_GPU if useGPU else self.setup_CPU
    switch_modules()

    # structure of the discretization, filled by `add_axis` and `add_component`
    self.axes, self.components = [], []

    # boundary conditions, filled by `add_BC` and assembled in `setup_BCs`
    self.full_BCs = []
    self.BC_mat = self.BCs = None

    self.fft_cache, self.fft_dealias_shape_cache = {}, {}

    self.logger = logging.getLogger(name='Spectral Discretization')
    if debug:
        self.logger.setLevel(logging.DEBUG)

991 

@property
def u_init(self):
    """
    Fresh data container in physical space, built from `self.init`.
    """
    layout = self.init
    return self.dtype(layout)

998 

@property
def u_init_forward(self):
    """
    Fresh data container in spectral space, built from `self.init_forward`.
    """
    layout = self.init_forward
    return self.dtype(layout)

1005 

@property
def u_init_physical(self):
    """
    Fresh data container in physical space, built from `self.init_physical`.
    """
    layout = self.init_physical
    return self.dtype(layout)

1012 

@property
def shape(self):
    """
    Shape of a single solution component, i.e. the local data shape without the leading component axis.
    """
    local_shape = self.init[0]
    return local_shape[1:]

1019 

@property
def ndim(self):
    """Number of spatial dimensions, i.e. the number of 1D bases added via `add_axis`."""
    bases = self.axes
    return len(bases)

1023 

@property
def ncomponents(self):
    """Number of solution components added via `add_component`."""
    names = self.components
    return len(names)

1027 

@property
def V(self):
    """
    Domain volume, computed as the product of the lengths `L` of all axes.
    """
    lengths = [base.L for base in self.axes]
    return np.prod(lengths)

1034 

def add_axis(self, base, *args, **kwargs):
    """
    Add an axis to the domain using the 1D base selected by `base`.
    All further arguments are passed on to the constructor of that base; see the 1D bases for what they accept.

    Args:
        base (str): 1D spectral method, e.g. 'chebychev', 'fft' or 'ultraspherical' (case insensitive)

    Raises:
        NotImplementedError: If `base` is not a known spectral method
    """
    kwargs['useGPU'] = self.useGPU
    kind = base.lower()

    if kind in ('chebychov', 'chebychev', 'cheby', 'chebychovhelper'):
        helper_cls = ChebychevHelper
    elif kind in ('fft', 'fourier', 'ffthelper'):
        helper_cls = FFTHelper
    elif kind in ('ultraspherical', 'gegenbauer'):
        helper_cls = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    new_axis = helper_cls(*args, **kwargs)
    # the 1D base has to use the same array and sparse modules as this helper
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
    self.axes.append(new_axis)

1056 

def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list/tuple of strings): Name(s) of component(s). String subclasses such as `numpy.str_` are
            accepted and stored as plain `str`.

    Raises:
        Exception: If a component of the same name has already been added
        NotImplementedError: If `name` is neither a string nor a list/tuple of strings
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        # normalize str subclasses (e.g. numpy.str_) to plain str
        self.components.append(str(name))
    else:
        raise NotImplementedError(f'Cannot add component of {type(name)=}: expected str or list/tuple of str')

1073 

def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str or list/tuple of strings): Name(s) of component(s)

    Returns:
        int: Index of the component, or a tuple of indices (nested like the input) for a list/tuple of names

    Raises:
        ValueError: If a name is not a known component
        NotImplementedError: If `name` is of an unsupported type
    """
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # a tuple (rather than a one-shot generator) can be unpacked, indexed and iterated repeatedly
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

1090 

def get_empty_operator_matrix(self, diag=False):
    """
    Build a container of zero operators that is then filled with the couplings between solution components.

    Args:
        diag (bool): Return a single list with one entry per component (block-diagonal operator) instead of a
            list of lists

    Returns:
        list (or list of lists) whose entries all refer to the same sparse zero matrix
    """
    n_comp = len(self.components)
    zero = self.get_Id() * 0

    if diag:
        return [zero] * n_comp
    # one independent row list per equation, all entries the same zero matrix
    return [[zero] * n_comp for _ in range(n_comp)]

1107 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than three dimensions
    """
    # The 1D matrix holding the BC row is always assembled with SciPy on the host
    sp = scipy.sparse

    base = self.axes[axis]

    # all-zero N x N matrix in LIL format so that a single row can be assigned
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # presumably the 1D base returns a device array on GPU; `.get()` copies it to the host
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        mat = self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # keep only the entry with index 0 in the other direction
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        # BC matrix along `axis`, (partial) identity along the other axis, restricted to the local modes
        mats = [
            None,
        ] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
    elif ndim == 3:
        mats = [
            None,
        ] * ndim

        for ax in range(ndim):
            if ax == axis:
                continue

            if scalar:
                # keep only the entry with index 0 in this direction
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
            else:
                _Id = self.axes[ax].get_Id()

            mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

        # `kron` takes two operands, hence the nesting
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
    else:
        raise NotImplementedError(
            f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
        )
    mat = self.eliminate_zeros(mat)
    return mat

1178 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    The right hand side mask is un-flagged with exactly the same index logic that `add_BC` uses to flag it. The
    (possibly negative) global `line` is mapped to this rank's local spectral index, and scalar BCs are only
    handled on rank 0 when running with MPI.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    _BC = self.eliminate_zeros(_BC)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # mirror `add_BC`: with MPI, scalar BCs are only flagged on rank 0
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = False
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        global_line = (N + line) % N
        # only touch the mask if this rank holds the BC row; translate the global row to the local index
        if global_line in self.get_indices(True)[axis]:
            slices[axis + 1] = global_line - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = False

1213 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    # embed the BC row into the operator block coupling `equation` and `component`
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    # record the BC; the record is read again when putting BC values into the right hand side and when checking BCs
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    # flag the right hand side entries that are replaced by this BC
    if scalar:
        # scalar BCs only occupy index 0 in the other direction(s)
        slices = [self.index(equation)] + [
            0,
        ] * self.ndim
        slices[axis + 1] = line
        if self.comm:
            # with MPI, only rank 0 flags the scalar BC
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        # map the (possibly negative) global line to this rank's local index, if this rank holds that row
        if (N + line) % N in self.get_indices(True)[axis]:
            slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = True

1260 

def setup_BCs(self):
    """
    Convert the list of lists of BCs to the boundary condition operator.
    Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
    condition. This method sets up a suitable sparse matrix to do this.

    Sets:
        BCs: Sparse matrix containing only the BC rows
        BC_zero_index: Flat indices of the rows that are replaced by BCs
        BC_line_zero_matrix: Diagonal matrix with zeros in the BC rows and ones elsewhere
        rhs_BCs_hat: Boundary conditions in spectral space, ready to be added to the right hand side
    """
    sp = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    # `BC_rhs_mask` is created by `newDistArray()` in the spectral layout, which is also the layout of the rows of
    # `BCs`. Index by the size of the mask itself rather than by the physical-space shape `init[0]`, which may
    # have a different local size when the data is distributed.
    mask = self.BC_rhs_mask.flatten()
    self.BC_zero_index = self.xp.arange(mask.size)[mask]

    diags = self.xp.ones(self.BCs.shape[0])
    diags[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sp.diags(diags)

    # prepare BCs in spectral space to easily add to the RHS
    rhs_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(rhs_BCs)

1278 

def check_BCs(self, u):
    """
    Check that the solution satisfies the boundary conditions

    Only non-scalar BCs are checked, and only up to two dimensions are supported.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a BC is violated, or for three or more dimensions
    """
    assert self.ndim < 3
    for axis in range(self.ndim):
        BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]

        if len(BCs) > 0:
            # transform only along the axis the BCs act on
            u_hat = self.transform(u, axes=(axis - self.ndim,))
            for BC in BCs:
                # the remaining entries of the record (e.g. `kind`) are the arguments of the 1D `get_BC`
                kwargs = {
                    key: value
                    for key, value in BC.items()
                    if key not in ['component', 'equation', 'axis', 'v', 'line', 'scalar']
                }

                # apply the BC row along `axis` to the component the BC acts on
                if axis == 0:
                    get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
                elif axis == 1:
                    get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
                want = BC['v']
                assert self.xp.allclose(
                    get, want
                ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'

1307 

def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions in a matrix by replacing rows with BCs.

    Args:
        A: Matrix to put the BCs in

    Returns:
        `A` with the BC rows zeroed out and the BC operator added
    """
    zeroed = self.BC_line_zero_matrix @ A
    return zeroed + self.BCs

1313 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.
    Note that the entries covered by the mask are zeroed in `rhs_hat` itself.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate (once) the mask of spectral rhs entries that are replaced by boundary conditions.
        self._rhs_hat_zero_mask = self.newDistArray().astype(bool)

        for axis in range(self.ndim):
            for bc in [me for me in self.full_BCs if me['axis'] == axis]:
                slices = [self.index(bc['equation']), *self.global_slice(True)]
                N = self.axes[axis].N
                row = (N + bc['line']) % N  # global row of the BC; negative lines count from the end
                if row not in self.get_indices(True)[axis]:
                    continue  # this rank does not hold the BC row
                slices[axis + 1] = row - self.local_slice()[axis].start
                self._rhs_hat_zero_mask[(*slices,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1344 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs

    Raises:
        AssertionError: If `rhs` is passed flattened
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        # go to spectral space along the current axis only
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        for bc in self.full_BCs:

            if axis == bc['axis']:
                _slice = [self.index(bc['equation']), *self.global_slice(True)]

                N = self.axes[axis].N
                line = bc['line']
                # write the BC value only if this rank holds the (possibly negative, hence wrapped) BC row
                if (N + line) % N in self.get_indices(True)[axis]:
                    _slice[axis + 1] = (N + line) % N - self.local_slice()[axis].start
                    _rhs_hat[(*_slice,)] = bc['v']

        # back to physical space before treating the next axis
        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs

1378 

def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled in place
        equation (str): The equation of the component you want this in
        relations: (dict): Maps component names to the operators acting on them in this equation
    """
    position = self.index
    for component, operator in relations.items():
        A[position(equation)][position(component)] = operator

1417 

def eliminate_zeros(self, A):
    """
    Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat.
    Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`.
    Therefore, this function copies the matrix to host, eliminates the zeros there and then copies back to GPU.

    Args:
        A: sparse matrix to be pruned

    Returns:
        CSC sparse matrix
    """
    on_gpu = self.useGPU
    host_matrix = (A.get() if on_gpu else A).tocsc()
    host_matrix.eliminate_zeros()
    return self.sparse_lib.csc_matrix(host_matrix) if on_gpu else host_matrix

1437 

def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator to be converted

    Returns:
        sparse linear operator
    """
    if len(self.components) != 1:
        op = self.sparse_lib.bmat(M, format='csc')
    else:
        # a single component needs no block assembly
        op = M[0][0]
    return self.eliminate_zeros(op)

1456 

def get_wavenumbers(self):
    """
    Get grid in spectral space, restricted to the modes held by this rank.

    Returns:
        Output of `xp.meshgrid` with 'ij' indexing, one array per axis
    """
    grids = []
    for i, base in enumerate(self.axes):
        grids.append(base.get_wavenumbers()[self.local_slice(True)[i]])
    return self.xp.meshgrid(*grids, indexing='ij')

1463 

def get_grid(self, forward_output=False):
    """
    Get grid in physical space, restricted to the points held by this rank.

    Args:
        forward_output (bool): Use the spectral (True) or physical (False) data distribution

    Returns:
        Output of `xp.meshgrid` with 'ij' indexing, one array per axis
    """
    grids = []
    for i, base in enumerate(self.axes):
        grids.append(base.get_1dgrid()[self.local_slice(forward_output)[i]])
    return self.xp.meshgrid(*grids, indexing='ij')

1470 

def get_indices(self, forward_output=True):
    """
    Global indices of the entries held by this rank, one array per axis.

    Args:
        forward_output (bool): Use the spectral (True) or physical (False) data distribution
    """
    indices = []
    for i, base in enumerate(self.axes):
        indices.append(self.xp.arange(base.N)[self.local_slice(forward_output)[i]])
    return indices

1473 

@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Get a parallel transform object from `mpi4py-fft` that uses the transforms of the 1D bases.
    Results are cached per argument combination by the `cache` decorator.

    Args:
        axes (tuple): Axes to transform with the 1D bases. The other axes get a no-op transform. Defaults to all axes.
        padding: Padding factor per axis. Defaults to 1 for every axis.
        grid: Forwarded to `mpi4py_fft.PFFT`

    Returns:
        `mpi4py_fft.PFFT` object, or `None` in 1D or when there is no communicator

    NOTE(review): the `cache` decorator puts the arguments into a dict key (`frozenset(kwargs.items())`), so they
    must be hashable, e.g. pass `padding` as a tuple rather than a list.
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    def no_transform(u, *args, **kwargs):
        # identity "transform" for axes that are not to be transformed
        return u

    # start with identity transforms on all axes and replace them by the 1D bases' transforms along `axes`
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )
    return pfft

1511 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    When using MPI, we use `PFFT` objects generated by mpi4py-fft

    Args:
        axes (tuple): Axes you want to transform over. Lists and single ints are accepted too.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform. Without a communicator, this is `xp.fft.fftn` / `xp.fft.ifftn`, or `None` for "object".
    """
    # Normalize `axes` to a tuple so that it can be part of the (hashable) cache key
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))
    elif isinstance(axes, (int, np.integer)):
        axes = (int(axes),)
    else:
        axes = tuple(axes)
    shape = self.global_shape[1:] if shape is None else shape
    padding = [1] * self.ndim if padding is None else padding
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache:
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # the forward / backward functions are methods of the (cached) PFFT object
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    return self.fft_cache[key]

1572 

def local_slice(self, forward_output=True):
    """
    Slices selecting the part of the global data held by this rank, one per axis.

    Args:
        forward_output (bool): Use the spectral (True) or physical (False) data distribution
    """
    if not self.fft_obj:
        # not distributed: every rank holds everything
        return [slice(0, base.N) for base in self.axes]
    return self.get_pfft().local_slice(forward_output=forward_output)

1578 

def global_slice(self, forward_output=True):
    """
    Slices covering the full global extent of every axis.

    Args:
        forward_output (bool): Use the spectral (True) or physical (False) global shape
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    global_shape = self.fft_obj.global_shape(forward_output=forward_output)
    return [slice(0, n) for n in global_shape]

1584 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if not self.components:
        self.add_component('u')

    self.global_shape = (len(self.components), *(base.N for base in self.axes))
    self.fft_obj = self.get_pfft(axes=tuple(range(len(self.axes))))

    def _local_shape(forward_output):
        # shape of the locally held part of the global data, including the leading component axis
        return np.empty(shape=self.global_shape)[(..., *self.local_slice(forward_output))].shape

    physical_shape = _local_shape(False)
    spectral_dtype = np.dtype('float') if real_spectral_coefficients else np.dtype('complex128')

    self.init = (physical_shape, self.comm, np.dtype('float'))
    self.init_physical = (physical_shape, self.comm, np.dtype('float'))
    self.init_forward = (_local_shape(True), self.comm, spectral_dtype)

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)

1634 

def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Args:
        pfft: `PFFT` object that defines the distribution. Defaults to `self.get_pfft()`
        forward_output (bool): Spectral (True) or physical (False) layout
        val: Fill value, passed to the distributed array class
        rank (int): Number of leading axes of size `ncomponents`
        view (bool): Return `z.v` instead of the distributed array object `z`

    Returns:
        Array including the leading component axis. Without a communicator, this is a zero array of shape `init[0]`
        and dtype `init[2]` regardless of `forward_output`, `val`, `rank` and `view`. If no `PFFT` object is
        available (e.g. in 1D), an empty data container from `u_init_forward` / `u_init` is returned.
    """
    if self.comm is None:
        return self.xp.zeros(self.init[0], dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft if pfft else self.get_pfft()
    if pfft is None:
        if forward_output:
            return self.u_init_forward
        else:
            return self.u_init

    global_shape = pfft.global_shape(forward_output)
    p0 = pfft.pencil[forward_output]
    if forward_output is True:
        dtype = pfft.forward.output_array.dtype
    else:
        dtype = pfft.forward.input_array.dtype
    # prepend the component axis (or axes, for rank > 1)
    global_shape = (self.ncomponents,) * rank + global_shape

    # GPU backends need the CuPy flavour of the distributed array
    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
    return z.v if view else z

1666 

def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Find the axes along which the local array `u` is aligned. These are the axes where the local size of `u`
    equals the global size of a distributed array from a matching `PFFT` object.

    Args:
        u: Local data including the leading component axis
        forward_output (bool): Compare against the spectral (True) or physical (False) layout
        padding: Padding factors possibly used for `u`. If given, combinations of padded and unpadded axes are tried
            in turn until an alignment is found.
        **kwargs: Forwarded to `get_pfft`

    Returns:
        list of aligned axes (`[0]` without a communicator)

    Raises:
        NotImplementedError: If `padding` is given for other than 2 or 3 dimensions
        AssertionError: If no aligned axis is found
    """
    if self.comm is None:
        return [0]

    def _alignment(pfft):
        # axes where the global size of a matching distributed array equals the local size of `u`
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        # NOTE(review): `padding` itself is one of the options and goes into the cached `get_pfft`, so it should be
        # hashable (a tuple) — confirm with callers
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes

1706 

def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Return a distributed array with the data of `u`, aligned in `axis`.

    Args:
        u: Local data including the leading component axis
        axis (int): Axis the result should be aligned in
        forward_output (bool): Layout of the result array, spectral (True) or physical (False)
        **kwargs: Forwarded to `get_pfft` and `infer_alignment`

    Returns:
        Redistributed array, or `u` itself without a communicator

    Raises:
        Exception: If no alignment of a matching array has the local shape of `u`
    """
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # NOTE: `and False` disables this branch, so it is currently dead code
    if 'Dist' in type(u).__name__ and False:
        try:
            u.redistribute(out=_arr)
            return _arr
        except AssertionError:
            pass

    # NOTE(review): the alignment of `u` is always inferred with `forward_output=False`, independent of the
    # `forward_output` argument — confirm this is intended for calls from `itransform`
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        # bring the target array into the alignment of `u`, copy, then move to the requested alignment
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )

1731 

def transform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from physical to spectral space.

    Args:
        u: Data in physical space including the leading component axis
        axes (tuple): Axes to transform. Defaults to all axes.
        padding: Padding factors for dealiasing. They are only used when a `PFFT` object exists (more than one
            dimension and a communicator).
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in spectral space
    """
    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if pfft is None:
        # serial path: transform axis by axis with the 1D bases; non-negative axes are shifted past the component axis
        axes = axes if axes else tuple(i for i in range(self.ndim))
        u_hat = u.copy()
        for i in axes:
            _axis = 1 + i if i >= 0 else i
            u_hat = self.axes[i].transform(u_hat, axes=(_axis,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=False, rank=1)
    _out = self.newDistArray(pfft, forward_output=True, rank=1)

    if _in.shape == u.shape:
        _in[...] = u
    else:
        # `u` has a different local shape, e.g. it is distributed differently
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs)

    # transform each component separately
    for i in range(self.ncomponents):
        pfft.forward(_in[i], _out[i], normalize=False)

    if padding is not None:
        # NOTE(review): rescaling by the padding factor; presumably compensates the normalization of the padded
        # transform — confirm against mpi4py-fft
        _out /= np.prod(padding)
    return _out

1757 

def itransform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Transform `u` from spectral to physical space.

    Args:
        u: Data in spectral space including the leading component axis
        axes (tuple): Axes to transform. Defaults to all axes.
        padding: Padding factors for dealiasing. Each `N * padding` must be an integer. They are only used when a
            `PFFT` object exists (more than one dimension and a communicator).
        *args, **kwargs: Forwarded to `get_pfft`

    Returns:
        Data in physical space

    Raises:
        AssertionError: If the padded resolution is not an integer
    """
    if padding is not None:
        assert all(
            (self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim)
        ), 'Cannot do this padding with this resolution. Resulting resolution must be integer'

    pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)
    if pfft is None:
        # serial path: transform axis by axis with the 1D bases; non-negative axes are shifted past the component axis
        axes = axes if axes else tuple(i for i in range(self.ndim))
        u_hat = u.copy()
        for i in axes:
            _axis = 1 + i if i >= 0 else i
            u_hat = self.axes[i].itransform(u_hat, axes=(_axis,))
        return u_hat

    _in = self.newDistArray(pfft, forward_output=True, rank=1)
    _out = self.newDistArray(pfft, forward_output=False, rank=1)

    if _in.shape == u.shape:
        _in[...] = u
    else:
        # `u` has a different local shape, e.g. it is distributed differently
        _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs)

    # transform each component separately
    for i in range(self.ncomponents):
        pfft.backward(_in[i], _out[i], normalize=True)

    if padding is not None:
        # undo the scaling applied in `transform`
        _out *= np.prod(padding)
    return _out

1787 

def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
    which you can sort out via the `SpectralHelper.local_slice()` attribute. When constructing a 1D matrix, you can
    use this method to get the part corresponding to the modes carried by this rank.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions.

    Returns:
        sparse local matrix (CSC)
    """
    local_modes = self.local_slice(True)[axis]
    return M.tocsc()[local_modes, local_modes]

1802 

def expand_matrix_ND(self, matrix, aligned):
    """
    Embed a 1D matrix acting along axis `aligned` into the N-D space by Kronecker products with identities in the
    other directions. In more than one dimension, all factors are restricted to the modes held by this rank.

    Args:
        matrix (sparse matrix): Global 1D matrix
        aligned (int): Axis the matrix acts along

    Returns:
        CSC sparse matrix

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = self.sparse_lib
    other_axes = np.delete(np.arange(self.ndim), aligned)
    ndim = len(other_axes) + 1

    if ndim == 1:
        return self.eliminate_zeros(matrix)
    if ndim > 3:
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    factors = [None] * ndim
    factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
    for axis in other_axes:
        factors[axis] = self.get_local_slice_of_1D_matrix(sp.eye(self.axes[axis].N), axis)

    # `kron` takes two operands: fold from the right, giving kron(A, B) in 2D and kron(A, kron(B, C)) in 3D
    expanded = factors[-1]
    for factor in reversed(factors[:-1]):
        expanded = sp.kron(factor, expanded)
    return self.eliminate_zeros(expanded)

1834 

def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible.

    Args:
        axis (int): Axis along which to filter. Identities are used in all other directions.

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    mats = [base.get_Id() for base in self.axes]
    mats[axis] = self.axes[axis].get_filter_matrix(**kwargs)

    # `kron` only accepts two matrices (a third positional argument would be taken as `format`), so fold from the
    # right: kron(A, B) in 2D and kron(A, kron(B, C)) in 3D
    result = mats[-1]
    for mat in reversed(mats[:-1]):
        result = self.sparse_lib.kron(mat, result)
    return result

1849 

def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix, the product of the expanded 1D matrices in the order of `axes`
    """

    def expanded(axis):
        # 1D differentiation matrix along `axis`, embedded in the N-D space
        return self.expand_matrix_ND(self.axes[axis].get_differentiation_matrix(**kwargs), axis)

    D = expanded(axes[0])
    for axis in axes[1:]:
        D = D @ expanded(axis)
    return D

1866 

def get_integration_matrix(self, axes):
    """
    Assemble the N-dimensional integration matrix for the requested axes.

    The 1D integration matrix of each listed axis is lifted into the full operator space with `expand_matrix_ND`.
    The lifted matrices are then chained by matrix multiplication, following the order of `axes`.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """
    leading_axis = axes[0]
    S = self.expand_matrix_ND(self.axes[leading_axis].get_integration_matrix(), leading_axis)

    for ax in axes[1:]:
        lifted = self.expand_matrix_ND(self.axes[ax].get_integration_matrix(), ax)
        S = S @ lifted

    return S

1883 

def get_Id(self):
    """
    Assemble the identity matrix on the full N-dimensional space.

    Each base's own identity is lifted with `expand_matrix_ND`. The results for axes `0 .. ndim-1` are then
    multiplied together in ascending axis order.

    Returns:
        sparse identity matrix
    """

    def _lifted_identity(ax):
        # identity of the base on axis `ax`, embedded into the full space
        return self.expand_matrix_ND(self.axes[ax].get_Id(), ax)

    identity = _lifted_identity(0)
    for ax in range(1, self.ndim):
        identity = identity @ _lifted_identity(ax)
    return identity

1896 

def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get the Dirichlet recombination matrix along `axis`, lifted to the full N-dimensional space.

    Note that this only makes sense in directions discretized with variations of Chebychev bases. The 1D matrix is
    requested from the base of `axis` and embedded with `expand_matrix_ND`.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    return self.expand_matrix_ND(self.axes[axis].get_Dirichlet_recombination_matrix(), axis)

1909 

def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Assemble a matrix that changes the spectral basis along the given axes.

    Some spectral bases change basis while differentiating; this method builds the matrix that changes the basis to
    whatever you want. Which `kwargs` are needed is described in the methods of the same name of the 1D bases.

    If `axes` is omitted, all axes are used, ordered from the last to the first, i.e. ``(-1, -2, ..., -ndim)``. Each
    1D basis change matrix is lifted with `expand_matrix_ND`, and the results are multiplied in the order of `axes`.

    Args:
        axes (tuple): Axes along which to change basis.
        **kwargs: Forwarded unchanged to `get_basis_change_matrix` of each 1D base.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(range(-1, -self.ndim - 1, -1))

    def _lifted(ax):
        # 1D basis change of the base on axis `ax`, embedded into the full space
        return self.expand_matrix_ND(self.axes[ax].get_basis_change_matrix(**kwargs), ax)

    C = _lifted(axes[0])
    for ax in axes[1:]:
        C = C @ _lifted(ax)
    return C