Coverage for pySDC/helpers/spectral_helper.py: 90%

782 statements  

coverage.py v7.10.6, created at 2025-09-04 15:08 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5from functools import partial, wraps 

6import logging 

7 

8 

9def cache(func): 

10 """ 

11 Decorator for caching return values of instance methods. 

12 This is very similar to `functools.cache`, but without the memory leaks (see 

13 https://docs.astral.sh/ruff/rules/cached-instance-method/). 

14 

15 Example: 

16 

17 .. code-block:: python 

18 

19 class Example: 
20     num_calls = 0 
21 
22     @cache 
23     def increment(self, x): 
24         self.num_calls += 1 
25         return x + 1 
26 
27 example = Example() 
28 example.increment(0); example.increment(1); example.increment(0)  # returns 1, 2, 1; num_calls == 2 because the last call is served from the cache 

29 

30 

31 Args: 

32 func (function): The function you want to cache the return value of 

33 

34 Returns: 

35 Wrapped version of func that caches its return values 

36 """ 

37 attr_cache = f"_{func.__name__}_cache" 

38 

39 @wraps(func) 

40 def wrapper(self, *args, **kwargs): 

41 if not hasattr(self, attr_cache): 

42 setattr(self, attr_cache, {}) 

43 

44 cache = getattr(self, attr_cache) 

45 

46 key = (args, frozenset(kwargs.items())) 

47 if key in cache: 

48 return cache[key] 

49 result = func(self, *args, **kwargs) 

50 cache[key] = result 

51 return result 

52 

53 return wrapper 

54 

55 

56class SpectralHelper1D: 

57 """ 

58 Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that 

59 all bases need to have. 

60 

61 When implementing new bases, please take care to use the modules that are supplied as class attributes to enable 

62 the code for GPUs. 

63 

64 Attributes: 

65 N (int): Resolution 

66 x0 (float): Coordinate of left boundary 

67 x1 (float): Coordinate of right boundary 

68 L (float): Length of the domain 

69 useGPU (bool): Whether to use GPUs 

70 

71 """ 

72 

73 fft_lib = scipy.fft 

74 sparse_lib = scipy.sparse 

75 linalg = scipy.sparse.linalg 

76 xp = np 

77 distributable = False 

78 

79 def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False): 

80 """ 

81 Constructor 

82 

83 Args: 

84 N (int): Resolution 

85 x0 (float): Coordinate of left boundary 

86 x1 (float): Coordinate of right boundary 

87 useGPU (bool): Whether to use GPUs 

88 useFFTW (bool): Whether to use FFTW for the transforms 

89 """ 

90 self.N = N 

91 self.x0 = x0 

92 self.x1 = x1 

93 self.L = x1 - x0 

94 self.useGPU = useGPU 

95 self.plans = {} 

96 self.logger = logging.getLogger(name=type(self).__name__) 

97 

98 if useGPU: 

99 self.setup_GPU() 

100 else: 

101 self.setup_CPU(useFFTW=useFFTW) 

102 

103 if useGPU and useFFTW: 

104 raise ValueError('Please run either on GPUs or with FFTW, not both!') 

105 

106 @classmethod 

107 def setup_GPU(cls): 

108 """switch to GPU modules""" 

109 import cupy as cp 

110 import cupyx.scipy.sparse as sparse_lib 

111 import cupyx.scipy.sparse.linalg as linalg 

112 import cupyx.scipy.fft as fft_lib 

113 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh 

114 

115 cls.xp = cp 

116 cls.sparse_lib = sparse_lib 

117 cls.linalg = linalg 

118 cls.fft_lib = fft_lib 

119 

120 @classmethod 

121 def setup_CPU(cls, useFFTW=False): 

122 """switch to CPU modules""" 

123 

124 cls.xp = np 

125 cls.sparse_lib = scipy.sparse 

126 cls.linalg = scipy.sparse.linalg 

127 

128 if useFFTW: 

129 from mpi4py_fft import fftw 

130 

131 cls.fft_backend = 'fftw' 

132 cls.fft_lib = fftw 

133 else: 

134 cls.fft_backend = 'scipy' 

135 cls.fft_lib = scipy.fft 

136 

137 cls.fft_comm_backend = 'MPI' 

138 cls.dtype = mesh 

139 

140 def get_Id(self): 

141 """ 

142 Get identity matrix 

143 

144 Returns: 

145 sparse diagonal identity matrix 

146 """ 

147 return self.sparse_lib.eye(self.N) 

148 

149 def get_zero(self): 

150 """ 

151 Get a matrix with all zeros of the correct size. 

152 

153 Returns: 

154 sparse matrix with zeros everywhere 

155 """ 

156 return 0 * self.get_Id() 

157 

158 def get_differentiation_matrix(self): 

159 raise NotImplementedError() 

160 

161 def get_integration_matrix(self): 

162 raise NotImplementedError() 

163 

164 def get_wavenumbers(self): 

165 """ 

166 Get the grid in spectral space 

167 """ 

168 raise NotImplementedError 

169 

170 def get_empty_operator_matrix(self, S, O): 

171 """ 

172 Return a matrix of operators to be filled with the connections between the solution components. 

173 

174 Args: 

175 S (int): Number of components in the solution 

176 O (sparse matrix): Zero matrix used for initialization 

177 

178 Returns: 

179 list of lists containing sparse zeros 

180 """ 

181 return [[O for _ in range(S)] for _ in range(S)] 

182 

183 def get_basis_change_matrix(self, *args, **kwargs): 

184 """ 

185 Some spectral discretizations change the basis during differentiation. This method can be used to transfer 

186 between the various bases. 

187 

188 This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi- 

189 dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization. 

190 The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different 

191 base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative 

192 base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical 

193 discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix 

194 you want. This is what the `SpectralHelper` does when you call the method of the same name on it. 

195 

196 Returns: 

197 sparse basis change matrix 
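Example (a minimal illustrative sketch; it uses the `FFTHelper` and `UltrasphericalHelper` bases defined further down in this module, with an arbitrary resolution):

.. code-block:: python

    from pySDC.helpers.spectral_helper import FFTHelper, UltrasphericalHelper

    fft = FFTHelper(N=8)
    gegenbauer = UltrasphericalHelper(N=8)

    # the FFT base ignores the arguments and simply returns the identity
    I = fft.get_basis_change_matrix(p_in=0, p_out=1)

    # the ultraspherical base returns a genuine conversion between derivative bases
    S = gegenbauer.get_basis_change_matrix(p_in=0, p_out=1)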

198 """ 

199 return self.sparse_lib.eye(self.N) 

200 

201 def get_BC(self, kind): 

202 """ 

203 To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically, 

204 as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in 

205 the matrix are replaced by the boundary conditions as obtained by this method. 

206 

207 Args: 

208 kind (str): The type of BC you want to implement. Please refer to the implementations of this method in the 
209 individual 1D bases for what is implemented. 

210 

211 Returns: 

212 self.xp.array: Boundary condition 

213 """ 

214 raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!') 

215 

216 def get_filter_matrix(self, kmin=0, kmax=None): 

217 """ 

218 Get a bandpass filter. 

219 

220 Args: 

221 kmin (int): Lower limit of the bandpass filter 

222 kmax (int): Upper limit of the bandpass filter 

223 

224 Returns: 

225 sparse matrix 
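Example (a minimal illustrative sketch using the `FFTHelper` base defined below; resolution and data are arbitrary):

.. code-block:: python

    import numpy as np
    from pySDC.helpers.spectral_helper import FFTHelper

    helper = FFTHelper(N=32)
    u_hat = helper.transform(np.random.random(32))

    # keep the modes with kmin <= |k| < kmax and zero out everything else
    F = helper.get_filter_matrix(kmin=0, kmax=8)
    u_hat_filtered = F @ u_hat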

226 """ 

227 

228 k = abs(self.get_wavenumbers()) 

229 

230 kmax = max(k) if kmax is None else kmax 

231 

232 mask = self.xp.logical_or(k >= kmax, k < kmin) 

233 

234 if self.useGPU: 

235 Id = self.get_Id().get() 

236 else: 

237 Id = self.get_Id() 

238 F = Id.tolil() 

239 F[:, mask] = 0 

240 return F.tocsc() 

241 

242 def get_1dgrid(self): 

243 """ 

244 Get the grid in physical space 

245 

246 Returns: 

247 self.xp.array: Grid 

248 """ 

249 raise NotImplementedError 

250 

251 

252class ChebychevHelper(SpectralHelper1D): 

253 """ 

254 The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform 

255 between physical and spectral space by discrete cosine transform. 

256 The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator 

257 that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to 

258 be formulated in first order formulation. 

259 

260 This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068). 
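Example (a minimal roundtrip sketch on the default domain [-1, 1]; the resolution is arbitrary):

.. code-block:: python

    import numpy as np
    from pySDC.helpers.spectral_helper import ChebychevHelper

    helper = ChebychevHelper(N=8)
    x = helper.get_1dgrid()       # Chebychev points, clustered at the boundaries
    u = np.sin(x)

    u_hat = helper.transform(u)   # Chebychev T coefficients via DCT
    assert np.allclose(helper.itransform(u_hat), u)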

261 """ 

262 

263 def __init__(self, *args, x0=-1, x1=1, **kwargs): 

264 """ 

265 Constructor. 

266 Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you 

267 may choose to run on GPUs via the `useGPU` argument. 

268 

269 Args: 

270 x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented 
271 x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented 

272 """ 

273 # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2 

274 self.lin_trf_fac = (x1 - x0) / 2 

275 self.lin_trf_off = (x1 + x0) / 2 

276 super().__init__(*args, x0=x0, x1=x1, **kwargs) 

277 

278 self.norm = self.get_norm() 

279 

280 def get_1dgrid(self): 

281 ''' 

282 Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to 

283 use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you 

284 need to do an affine transformation before any Chebychev business. 

285 

286 Returns: 

287 numpy.ndarray: 1D grid 

288 ''' 

289 return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off 

290 

291 def get_wavenumbers(self): 

292 """Get the domain in spectral space""" 

293 return self.xp.arange(self.N) 

294 

295 @cache 

296 def get_conv(self, name, N=None): 

297 ''' 

298 Get conversion matrix between different kinds of polynomials. The supported kinds are 

299 - T: Chebychev polynomials of first kind 

300 - U: Chebychev polynomials of second kind 

301 - D: Dirichlet recombination. 

302 

303 You get the desired matrix by choosing a name of the form ``A2B``, e.g. ``T2U`` for the conversion matrix from T to U. 
304 Once generated, matrices are cached, so feel free to call the method as often as you like. 

305 

306 Args: 

307 name (str): Conversion code, e.g. 'T2U' 

308 N (int): Size of the matrix (optional) 

309 

310 Returns: 

311 scipy.sparse: Sparse conversion matrix 
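Example (a minimal illustrative sketch; the resolution is arbitrary):

.. code-block:: python

    from pySDC.helpers.spectral_helper import ChebychevHelper

    helper = ChebychevHelper(N=6)
    T2U = helper.get_conv('T2U')          # convert T coefficients to U coefficients
    U2T = helper.get_conv('U2T')          # inverse conversion, obtained by inverting T2U
    assert T2U is helper.get_conv('T2U')  # repeated calls return the cached matrix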

312 ''' 

313 N = N if N else self.N 

314 sp = self.sparse_lib 

315 

316 def get_forward_conv(name): 

317 if name == 'T2U': 

318 mat = (sp.eye(N) - sp.eye(N, k=2)).tocsc() / 2.0 

319 mat[:, 0] *= 2 

320 elif name == 'D2T': 

321 mat = sp.eye(N) - sp.eye(N, k=2) 

322 elif name[0] == name[-1]: 

323 mat = self.sparse_lib.eye(self.N) 

324 else: 

325 raise NotImplementedError(f'Don\'t have conversion matrix {name!r}') 

326 return mat 

327 

328 try: 

329 mat = get_forward_conv(name) 

330 except NotImplementedError as E: 

331 try: 

332 fwd = get_forward_conv(name[::-1]) 

333 import scipy.sparse as sp 

334 

335 if self.sparse_lib == sp: 

336 mat = self.sparse_lib.linalg.inv(fwd.tocsc()) 

337 else: 

338 mat = self.sparse_lib.csc_matrix(sp.linalg.inv(fwd.tocsc().get())) 

339 except NotImplementedError: 

340 raise NotImplementedError from E 

341 

342 return mat 

343 

344 def get_basis_change_matrix(self, conv='T2T', **kwargs): 

345 """ 

346 As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to 

347 Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a 

348 conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other 

349 bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`. 

350 

351 Args: 

352 conv (str): Conversion code, e.g. 'T2U' 

353 

354 Returns: 

355 Sparse conversion matrix 

356 """ 

357 return self.get_conv(conv) 

358 

359 def get_integration_matrix(self, lbnd=0): 

360 """ 

361 Get matrix for integration 

362 

363 Args: 

364 lbnd (float): Lower bound for integration, only 0 is currently implemented 

365 

366 Returns: 

367 Sparse integration matrix 

368 """ 

369 S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U') 

370 n = self.xp.arange(self.N) 

371 if lbnd == 0: 

372 S = S.tocsc() 

373 S[0, 1::2] = ( 

374 (n / (2 * (self.xp.arange(self.N) + 1)))[1::2] 

375 * (-1) ** (self.xp.arange(self.N // 2)) 

376 / (np.append([1], self.xp.arange(self.N // 2 - 1) + 1)) 

377 ) * self.lin_trf_fac 

378 else: 

379 raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.') 

380 return S 

381 

382 def get_differentiation_matrix(self, p=1): 

383 ''' 

384 Keep in mind that the T2T differentiation matrix is dense. 

385 

386 Args: 

387 p (int): Derivative you want to compute 

388 

389 Returns: 

390 numpy.ndarray: Differentiation matrix 
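Example (a small worked sketch on the default domain [-1, 1]):

.. code-block:: python

    import numpy as np
    from pySDC.helpers.spectral_helper import ChebychevHelper

    helper = ChebychevHelper(N=4)
    D = helper.get_differentiation_matrix()

    u_hat = np.array([0.5, 0.0, 0.5, 0.0])               # T coefficients of u(x) = x**2
    assert np.allclose(D @ u_hat, [0.0, 2.0, 0.0, 0.0])  # T coefficients of u'(x) = 2*x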

391 ''' 

392 D = self.xp.zeros((self.N, self.N)) 

393 for j in range(self.N): 

394 for k in range(j): 

395 D[k, j] = 2 * j * ((j - k) % 2) 

396 

397 D[0, :] /= 2 

398 return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p 

399 

400 @cache 

401 def get_norm(self, N=None): 

402 ''' 

403 Get normalization factors for converting between DCT output and Chebychev coefficients 

404 

405 Args: 

406 N (int, optional): Resolution 

407 

408 Returns: 

409 self.xp.array: Normalization 

410 ''' 

411 N = self.N if N is None else N 

412 norm = self.xp.ones(N) / N 

413 norm[0] /= 2 

414 return norm 

415 

416 def transform(self, u, *args, axes=None, shape=None, **kwargs): 

417 """ 

418 DCT along axes. `kwargs` will be passed on to the FFT library. 

419 

420 Args: 

421 u: Data you want to transform 

422 axes (tuple): Axes you want to transform along 

423 

424 Returns: 

425 Data in spectral space 

426 """ 

427 axes = axes if axes else tuple(i for i in range(u.ndim)) 

428 kwargs['s'] = shape 

429 kwargs['norm'] = kwargs.get('norm', 'backward') 

430 

431 trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs) 

432 for axis in axes: 

433 

434 if self.N < trf.shape[axis]: 

435 # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are 

436 # removed in the middle rather than the end. We need to resort this here and put the highest frequencies 

437 # in the middle. 

438 _trf = self.xp.zeros_like(trf) 

439 N = self.N 

440 N_pad = _trf.shape[axis] - N 

441 end_first_half = N // 2 + 1 

442 

443 # copy first "half" 

444 su = [slice(None)] * trf.ndim 

445 su[axis] = slice(0, end_first_half) 

446 _trf[tuple(su)] = trf[tuple(su)] 

447 

448 # copy second "half" 

449 su = [slice(None)] * u.ndim 

450 su[axis] = slice(end_first_half + N_pad, None) 

451 s_u = [slice(None)] * u.ndim 

452 s_u[axis] = slice(end_first_half, N) 

453 _trf[tuple(su)] = trf[tuple(s_u)] 

454 

455 # # copy values to be cut 

456 # su = [slice(None)] * u.ndim 

457 # su[axis] = slice(end_first_half, end_first_half + N_pad) 

458 # s_u = [slice(None)] * u.ndim 

459 # s_u[axis] = slice(-N_pad, None) 

460 # _trf[tuple(su)] = trf[tuple(s_u)] 

461 

462 trf = _trf 

463 

464 expansion = [np.newaxis for _ in u.shape] 

465 expansion[axis] = slice(0, u.shape[axis], 1) 

466 norm = self.xp.ones(trf.shape[axis]) * self.norm[-1] 

467 norm[: self.N] = self.norm 

468 trf *= norm[(*expansion,)] 

469 return trf 

470 

471 def itransform(self, u, *args, axes=None, shape=None, **kwargs): 

472 """ 

473 Inverse DCT along axis. 

474 

475 Args: 

476 u: Data you want to transform 

477 axes (tuple): Axes you want to transform along 

478 

479 Returns: 

480 Data in physical space 

481 """ 

482 axes = axes if axes else tuple(i for i in range(u.ndim)) 

483 kwargs['s'] = shape 

484 kwargs['norm'] = kwargs.get('norm', 'backward') 

485 kwargs['overwrite_x'] = kwargs.get('overwrite_x', False) 

486 

487 for axis in axes: 

488 

489 if self.N == u.shape[axis]: 

490 _u = u.copy() 

491 else: 

492 # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are 

493 # added in the middle rather than the end. We need to resort this here and put the padding in the end. 

494 N = self.N 

495 _u = self.xp.zeros_like(u) 

496 

497 # copy first half 

498 su = [slice(None)] * u.ndim 

499 su[axis] = slice(0, N // 2 + 1) 

500 _u[tuple(su)] = u[tuple(su)] 

501 

502 # copy second half 

503 su = [slice(None)] * u.ndim 

504 su[axis] = slice(-(N // 2), None) 

505 s_u = [slice(None)] * u.ndim 

506 s_u[axis] = slice(N // 2, N // 2 + (N // 2)) 

507 _u[tuple(s_u)] = u[tuple(su)] 

508 

509 if N % 2 == 0: 

510 su = [slice(None)] * u.ndim 

511 su[axis] = N // 2 

512 _u[tuple(su)] *= 2 

513 

514 # generate norm 

515 expansion = [np.newaxis for _ in u.shape] 

516 expansion[axis] = slice(0, u.shape[axis], 1) 

517 norm = self.xp.ones(_u.shape[axis]) 

518 norm[: self.N] = self.norm 

519 norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N 

520 

521 _u /= norm[(*expansion,)] 

522 

523 return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs) 

524 

525 def get_BC(self, kind, **kwargs): 

526 """ 

527 Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of 

528 the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to 

529 set the BC. 

530 

531 Args: 

532 kind ('integral', 'dirichlet' or 'neumann'): Kind of boundary condition you want 
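Example (a minimal illustrative sketch; the resolution is arbitrary):

.. code-block:: python

    from pySDC.helpers.spectral_helper import ChebychevHelper

    helper = ChebychevHelper(N=8)
    bc_left = helper.get_BC('dirichlet', x=-1)   # values of the T polynomials at x=-1
    bc_integral = helper.get_BC('integral')      # integrals of the T polynomials over the domain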

533 """ 

534 if kind.lower() == 'integral': 

535 return self.get_integ_BC_row(**kwargs) 

536 elif kind.lower() == 'dirichlet': 

537 return self.get_Dirichlet_BC_row(**kwargs) 

538 elif kind.lower() == 'neumann': 

539 return self.get_Neumann_BC_row(**kwargs) 

540 else: 

541 return super().get_BC(kind) 

542 

543 def get_integ_BC_row(self): 

544 """ 

545 Get a row for generating integral BCs with T polynomials. 

546 It returns the values of the integrals of T polynomials over the entire interval. 

547 

548 Returns: 

549 self.xp.ndarray: Row to put into a matrix 

550 """ 

551 n = self.xp.arange(self.N) + 1 

552 me = self.xp.zeros_like(n).astype(float) 

553 me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2) 

554 me[0] = 2.0 

555 return me 

556 

557 def get_Dirichlet_BC_row(self, x): 

558 """ 

559 Get a row for generating Dirichlet BCs at x with T polynomials. 

560 It returns the values of the T polynomials at x. 

561 

562 Args: 

563 x (float): Position of the boundary condition 

564 

565 Returns: 

566 self.xp.ndarray: Row to put into a matrix 

567 """ 

568 if x == -1: 

569 return (-1) ** self.xp.arange(self.N) 

570 elif x == 1: 

571 return self.xp.ones(self.N) 

572 elif x == 0: 

573 n = (1 + (-1) ** self.xp.arange(self.N)) / 2 

574 n[2::4] *= -1 

575 return n 

576 else: 

577 raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!') 

578 

579 def get_Neumann_BC_row(self, x): 

580 """ 

581 Get a row for generating Neumann BCs at x with T polynomials. 

582 

583 Args: 

584 x (float): Position of the boundary condition 

585 

586 Returns: 

587 self.xp.ndarray: Row to put into a matrix 

588 """ 

589 n = self.xp.arange(self.N, dtype='D') 

590 nn = n**2 

591 if x == -1: 

592 me = nn 

593 me[1:] *= (-1) ** n[:-1] 

594 return me 

595 elif x == 1: 

596 return nn 

597 else: 

598 raise NotImplementedError(f'Don\'t know how to generate Neumann BC\'s at {x=}!') 

599 

600 def get_Dirichlet_recombination_matrix(self): 

601 ''' 

602 Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions. 

603 This makes for a good right preconditioner. 

604 

605 Returns: 

606 scipy.sparse: Sparse conversion matrix 

607 ''' 

608 N = self.N 

609 sp = self.sparse_lib 

610 

611 return sp.eye(N) - sp.eye(N, k=2) 

612 

613 

614class UltrasphericalHelper(ChebychevHelper): 

615 """ 

616 This implementation follows https://doi.org/10.1137/120865458. 

617 The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials. 

618 The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal. 

619 There are also conversion operators from one derivative basis to the next that are sparse. 

620 

621 This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2. 

622 You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well. 
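Example (a minimal sketch of the bumping described above; the resolution is arbitrary):

.. code-block:: python

    from pySDC.helpers.spectral_helper import UltrasphericalHelper

    helper = UltrasphericalHelper(N=16)

    D2 = helper.get_differentiation_matrix(p=2)            # lives in the lambda=2 basis
    D1 = helper.get_differentiation_matrix(p=1)            # lives in the lambda=1 basis
    I = helper.get_Id()                                    # lives in the Chebychev T basis

    S02 = helper.get_basis_change_matrix(p_in=0, p_out=2)  # bump the T basis to lambda=2
    S12 = helper.get_basis_change_matrix(p_in=1, p_out=2)  # bump lambda=1 to lambda=2

    # all terms of, e.g., u'' + u' + u now share the lambda=2 basis
    L = D2 + S12 @ D1 + S02 @ I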

623 """ 

624 

625 def get_differentiation_matrix(self, p=1): 

626 """ 

627 Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily. 

628 

629 Args: 

630 p (int): Order of the derivative 

631 

632 Returns: 

633 sparse differentiation matrix 

634 """ 

635 sp = self.sparse_lib 

636 xp = self.xp 

637 N = self.N 

638 l = p 

639 return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p 

640 

641 def get_S(self, lmbda): 

642 """ 

643 Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in 

644 https://doi.org/10.1137/120865458. 

645 

646 Args: 

647 lmbda (int): Ingoing derivative base 

648 

649 Returns: 

650 sparse matrix: Conversion from derivative base lmbda to lmbda + 1 

651 """ 

652 N = self.N 

653 

654 if lmbda == 0: 

655 sp = scipy.sparse 

656 mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil() 

657 mat[:, 0] *= 2 

658 else: 

659 sp = self.sparse_lib 

660 xp = self.xp 

661 mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags( 

662 lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2 

663 ) 

664 

665 return self.sparse_lib.csc_matrix(mat) 

666 

667 def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs): 

668 """ 

669 Get a conversion matrix from derivative base `p_in` to `p_out`. 

670 

671 Args: 

672 p_out (int): Resulting derivative base 

673 p_in (int): Ingoing derivative base 

674 """ 

675 mat_fwd = self.sparse_lib.eye(self.N) 

676 for i in range(min([p_in, p_out]), max([p_in, p_out])): 

677 mat_fwd = self.get_S(i) @ mat_fwd 

678 

679 if p_out > p_in: 

680 return mat_fwd 

681 

682 else: 

683 # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing. 

684 import scipy.sparse as sp 

685 

686 if self.useGPU: 

687 mat_fwd = mat_fwd.get() 

688 

689 mat_bck = sp.linalg.inv(mat_fwd.tocsc()) 

690 

691 return self.sparse_lib.csc_matrix(mat_bck) 

692 

693 def get_integration_matrix(self): 

694 """ 

695 Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the 

696 integration constant such that integration starts from x=-1. 

697 

698 Example: 

699 

700 .. code-block:: python 

701 

702 import numpy as np 

703 from pySDC.helpers.spectral_helper import UltrasphericalHelper 

704 

705 N = 4 

706 helper = UltrasphericalHelper(N) 

707 coeffs = np.random.random(N) 

708 coeffs[-1] = 0 

709 

710 poly = np.polynomial.Chebyshev(coeffs) 

711 

712 S = helper.get_integration_matrix() 

713 U_hat = S @ coeffs 

714 U_hat[0] = helper.get_integration_constant(U_hat, axis=-1) 

715 

716 assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat) 

717 

718 Returns: 

719 sparse integration matrix 

720 """ 

721 return ( 

722 self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) 

723 @ self.get_basis_change_matrix(p_out=1, p_in=0) 

724 * self.lin_trf_fac 

725 ) 

726 

727 def get_integration_constant(self, u_hat, axis): 

728 """ 

729 Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details. 

730 

731 Args: 

732 u_hat: Solution in spectral space 

733 axis: Axis you want to integrate over 

734 

735 Returns: 

736 Integration constant, has one less dimension than `u_hat` 

737 """ 

738 slices = [ 

739 None, 

740 ] * u_hat.ndim 

741 slices[axis] = slice(1, u_hat.shape[axis]) 

742 return self.xp.sum(u_hat[(*slices,)] * (-1) ** (self.xp.arange(u_hat.shape[axis] - 1)), axis=axis) 

743 

744 

745class FFTHelper(SpectralHelper1D): 

746 distributable = True 

747 

748 def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs): 

749 """ 

750 Constructor. 

751 Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you 

752 may choose to run on GPUs via the `useGPU` argument. 

753 

754 Args: 

755 x0 (float, optional): Coordinate of left boundary 

756 x1 (float, optional): Coordinate of right boundary 

757 """ 

758 super().__init__(*args, x0=x0, x1=x1, **kwargs) 

759 

760 def get_1dgrid(self): 

761 """ 

762 We use equally spaced points, including the left boundary but not the right one, which coincides with the left boundary due to periodicity. 

763 """ 

764 dx = self.L / self.N 

765 return self.xp.arange(self.N) * dx + self.x0 

766 

767 def get_wavenumbers(self): 

768 """ 

769 Be careful: the wavenumbers follow the FFT ordering, with non-negative frequencies first and negative frequencies last, which can be unintuitive. 
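For example (illustrative, with arbitrary resolution):

.. code-block:: python

    from pySDC.helpers.spectral_helper import FFTHelper

    helper = FFTHelper(N=8)          # default domain [0, 2*pi)
    print(helper.get_wavenumbers())  # [ 0.  1.  2.  3. -4. -3. -2. -1.]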

770 """ 

771 return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L 

772 

773 def get_differentiation_matrix(self, p=1): 

774 """ 

775 This matrix is diagonal, which allows inverting it concurrently across modes. 

776 

777 Args: 

778 p (int): Order of the derivative 

779 

780 Returns: 

781 sparse differentiation matrix 
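Example (a minimal illustrative sketch on the default domain [0, 2*pi); the resolution is arbitrary):

.. code-block:: python

    import numpy as np
    from pySDC.helpers.spectral_helper import FFTHelper

    helper = FFTHelper(N=32)
    x = helper.get_1dgrid()
    D = helper.get_differentiation_matrix()

    u_hat = helper.transform(np.sin(x))
    du = helper.itransform(D @ u_hat)      # spectral derivative of sin(x)
    assert np.allclose(du.real, np.cos(x))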

782 """ 

783 k = self.get_wavenumbers() 

784 

785 if self.useGPU: 

786 if p > 1: 

787 # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time of writing. 

788 from scipy.sparse.linalg import matrix_power 

789 

790 D = self.sparse_lib.diags(1j * k).get() 

791 return self.sparse_lib.csc_matrix(matrix_power(D, p)) 

792 else: 

793 return self.sparse_lib.diags(1j * k) 

794 else: 

795 return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p) 

796 

797 def get_integration_matrix(self, p=1): 

798 """ 

799 Get integration matrix to compute `p`-th integral over the entire domain. 

800 

801 Args: 

802 p (int): Order of integral you want to compute 

803 

804 Returns: 

805 sparse integration matrix 

806 """ 

807 k = self.xp.array(self.get_wavenumbers(), dtype='complex128') 

808 k[0] = 1j * self.L 

809 return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p) 

810 

811 def get_plan(self, u, forward, *args, **kwargs): 

812 if self.fft_lib.__name__ == 'mpi4py_fft.fftw': 

813 if 'axes' in kwargs.keys(): 

814 kwargs['axes'] = tuple(kwargs['axes']) 

815 key = (forward, u.shape, args, *(me for me in kwargs.values())) 

816 if key in self.plans.keys(): 

817 return self.plans[key] 

818 else: 

819 self.logger.debug(f'Generating FFT plan for {key=}') 

820 transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs) 

821 self.plans[key] = transform 

822 

823 return self.plans[key] 

824 else: 

825 if forward: 

826 return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward')) 

827 else: 

828 return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward')) 

829 

830 def transform(self, u, *args, axes=None, shape=None, **kwargs): 

831 """ 

832 FFT along axes. `kwargs` are passed on to the FFT library. 

833 

834 Args: 

835 u: Data you want to transform 

836 axes (tuple): Axes you want to transform over 

837 

838 Returns: 

839 transformed data 

840 """ 

841 axes = axes if axes else tuple(i for i in range(u.ndim)) 

842 kwargs['s'] = shape 

843 plan = self.get_plan(u, *args, forward=True, axes=axes, **kwargs) 

844 return plan(u, *args, axes=axes, **kwargs) 

845 

846 def itransform(self, u, *args, axes=None, shape=None, **kwargs): 

847 """ 

848 Inverse FFT. 

849 

850 Args: 

851 u: Data you want to transform 

852 axes (tuple): Axes over which to transform 

853 

854 Returns: 

855 transformed data 

856 """ 

857 axes = axes if axes else tuple(i for i in range(u.ndim)) 

858 kwargs['s'] = shape 

859 plan = self.get_plan(u, *args, forward=False, axes=axes, **kwargs) 

860 return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes]) 

861 

862 def get_BC(self, kind): 

863 """ 

864 Get a sort of boundary condition. You can use `kind='integral'` to fix the integral, or `kind='Nyquist'`. 
865 The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero. 
866 You should set the Nyquist mode to zero when the solution in physical space is real and the resolution is even. 

867 

868 Args: 

869 kind ('integral' or 'nyquist'): Kind of BC 

870 

871 Returns: 

872 self.xp.ndarray: Boundary condition row 

873 """ 

874 if kind.lower() == 'integral': 

875 return self.get_integ_BC_row() 

876 elif kind.lower() == 'nyquist': 

877 assert ( 

878 self.N % 2 == 0 

879 ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis' 

880 BC = self.xp.zeros(self.N) 

881 BC[self.get_Nyquist_mode_index()] = 1 

882 return BC 

883 else: 

884 return super().get_BC(kind) 

885 

886 def get_Nyquist_mode_index(self): 

887 """ 

888 Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive 

889 counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you 

890 are best advised to set this mode zero if representing real functions on even-resolution grids is what you're 

891 after. 

892 

893 Returns: 

894 int: Index of the Nyquist mode 

895 """ 

896 k = self.get_wavenumbers() 

897 Nyquist_mode = min(k) 

898 return self.xp.where(k == Nyquist_mode)[0][0] 

899 

900 def get_integ_BC_row(self): 

901 """ 

902 Only the 0-mode has non-zero integral with FFT basis in periodic BCs 

903 """ 

904 me = self.xp.zeros(self.N) 

905 me[0] = self.L / self.N 

906 return me 

907 

908 

909class SpectralHelper: 

910 """ 

911 This class has three functions: 

912 - Easily assemble matrices containing multiple equations 

913 - Direct product of 1D bases to solve problems in more dimensions 

914 - Distribute the FFTs to facilitate concurrency. 

915 

916 Attributes: 

917 comm (mpi4py.Intracomm): MPI communicator 

918 debug (bool): Perform additional checks at extra computational cost 

919 useGPU (bool): Whether to use GPUs 

920 axes (list): List of 1D bases 

921 components (list): List of strings of the names of components in the equations 

922 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

923 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

924 BCs (sparse matrix): Matrix containing only the BCs 

925 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

926 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

927 BC_zero_index (self.xp.ndarray): Indices of rows in the matrix that are replaced by BCs 

928 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

929 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

930 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

931 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

932 init (tuple): This is the same `init` that is used throughout the problem classes 

933 init_forward (tuple): This is the equivalent of `init` in spectral space 
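Example (a minimal serial sketch without MPI; resolutions and data are arbitrary):

.. code-block:: python

    import numpy as np
    from pySDC.helpers.spectral_helper import SpectralHelper

    helper = SpectralHelper()
    helper.add_axis(base='fft', N=32)
    helper.add_axis(base='chebychev', N=16)
    helper.add_component('u')
    helper.setup_fft()

    u = helper.u_init
    u[0] = np.random.random(helper.shape)

    u_hat = helper.transform(u)    # to spectral space
    u2 = helper.itransform(u_hat)  # back to physical space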

934 """ 

935 

936 xp = np 

937 fft_lib = scipy.fft 

938 sparse_lib = scipy.sparse 

939 linalg = scipy.sparse.linalg 

940 dtype = mesh 

941 fft_backend = 'scipy' 

942 fft_comm_backend = 'MPI' 

943 

944 @classmethod 

945 def setup_GPU(cls): 

946 """switch to GPU modules""" 

947 import cupy as cp 

948 import cupyx.scipy.sparse as sparse_lib 

949 import cupyx.scipy.sparse.linalg as linalg 

950 import cupyx.scipy.fft as fft_lib 

951 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh 

952 

953 cls.xp = cp 

954 cls.sparse_lib = sparse_lib 

955 cls.linalg = linalg 

956 

957 cls.fft_lib = fft_lib 

958 cls.fft_backend = 'cupyx-scipy' 

959 cls.fft_comm_backend = 'NCCL' 

960 

961 cls.dtype = cupy_mesh 

962 

963 @classmethod 

964 def setup_CPU(cls, useFFTW=False): 

965 """switch to CPU modules""" 

966 

967 cls.xp = np 

968 cls.sparse_lib = scipy.sparse 

969 cls.linalg = scipy.sparse.linalg 

970 

971 if useFFTW: 

972 from mpi4py_fft import fftw 

973 

974 cls.fft_backend = 'fftw' 

975 cls.fft_lib = fftw 

976 else: 

977 cls.fft_backend = 'scipy' 

978 cls.fft_lib = scipy.fft 

979 

980 cls.fft_comm_backend = 'MPI' 

981 cls.dtype = mesh 

982 

983 def __init__(self, comm=None, useGPU=False, debug=False): 

984 """ 

985 Constructor 

986 

987 Args: 

988 comm (mpi4py.Intracomm): MPI communicator 

989 useGPU (bool): Whether to use GPUs 

990 debug (bool): Perform additional checks at extra computational cost 

991 """ 

992 self.comm = comm 

993 self.debug = debug 

994 self.useGPU = useGPU 

995 

996 if useGPU: 

997 self.setup_GPU() 

998 else: 

999 self.setup_CPU() 

1000 

1001 self.axes = [] 

1002 self.components = [] 

1003 

1004 self.full_BCs = [] 

1005 self.BC_mat = None 

1006 self.BCs = None 

1007 

1008 self.fft_cache = {} 

1009 self.fft_dealias_shape_cache = {} 

1010 

1011 self.logger = logging.getLogger(name='Spectral Discretization') 

1012 if debug: 

1013 self.logger.setLevel(logging.DEBUG) 

1014 

1015 @property 

1016 def u_init(self): 

1017 """ 

1018 Get empty data container in physical space 

1019 """ 

1020 return self.dtype(self.init) 

1021 

1022 @property 

1023 def u_init_forward(self): 

1024 """ 

1025 Get empty data container in spectral space 

1026 """ 

1027 return self.dtype(self.init_forward) 

1028 

1029 @property 

1030 def u_init_physical(self): 

1031 """ 

1032 Get empty data container in physical space 

1033 """ 

1034 return self.dtype(self.init_physical) 

1035 

1036 @property 

1037 def shape(self): 

1038 """ 

1039 Get shape of individual solution component 

1040 """ 

1041 return self.init[0][1:] 

1042 

1043 @property 

1044 def ndim(self): 

1045 return len(self.axes) 

1046 

1047 @property 

1048 def ncomponents(self): 

1049 return len(self.components) 

1050 

1051 @property 

1052 def V(self): 

1053 """ 

1054 Get domain volume 

1055 """ 

1056 return np.prod([me.L for me in self.axes]) 

1057 

1058 def add_axis(self, base, *args, **kwargs): 

1059 """ 

1060 Add an axis to the domain by deciding on a suitable 1D base. 

1061 Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D 

1062 bases for possible arguments. 

1063 

1064 Args: 

1065 base (str): 1D spectral method 

1066 """ 

1067 kwargs['useGPU'] = self.useGPU 

1068 

1069 if base.lower() in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']: 

1070 self.axes.append(ChebychevHelper(*args, **kwargs)) 

1071 elif base.lower() in ['fft', 'fourier', 'ffthelper']: 

1072 self.axes.append(FFTHelper(*args, **kwargs)) 

1073 elif base.lower() in ['ultraspherical', 'gegenbauer']: 

1074 self.axes.append(UltrasphericalHelper(*args, **kwargs)) 

1075 else: 

1076 raise NotImplementedError(f'{base=!r} is not implemented!') 

1077 self.axes[-1].xp = self.xp 

1078 self.axes[-1].sparse_lib = self.sparse_lib 

1079 

1080 def add_component(self, name): 

1081 """ 

1082 Add solution component(s). 

1083 

1084 Args: 

1085 name (str or list of strings): Name(s) of component(s) 

1086 """ 

1087 if type(name) in [list, tuple]: 

1088 for me in name: 

1089 self.add_component(me) 

1090 elif type(name) in [str]: 

1091 if name in self.components: 

1092 raise Exception(f'{name=!r} is already added to this problem!') 

1093 self.components.append(name) 

1094 else: 

1095 raise NotImplementedError 

1096 

1097 def index(self, name): 

1098 """ 

1099 Get the index of component `name`. 

1100 

1101 Args: 

1102 name (str or list of strings): Name(s) of component(s) 

1103 

1104 Returns: 

1105 int or generator: Index of the component, or a generator of indices if a list of names is passed 

1106 """ 

1107 if type(name) in [str, int]: 

1108 return self.components.index(name) 

1109 elif type(name) in [list, tuple]: 

1110 return (self.index(me) for me in name) 

1111 else: 

1112 raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}') 

1113 

1114 def get_empty_operator_matrix(self, diag=False): 

1115 """ 

1116 Return a matrix of operators to be filled with the connections between the solution components. 

1117 

1118 Args: 

1119 diag (bool): Whether operator is block-diagonal 

1120 

1121 Returns: 

1122 list containing sparse zeros 

1123 """ 

1124 S = len(self.components) 

1125 O = self.get_Id() * 0 

1126 if diag: 

1127 return [O for _ in range(S)] 

1128 else: 

1129 return [[O for _ in range(S)] for _ in range(S)] 

1130 

1131 def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs): 

1132 """ 

1133 Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix. 

1134 Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines. 

1135 Typically, the last line that does not contain a BC is the best choice. 

1136 Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details. 

1137 

1138 Args: 

1139 axis (int): Axis you want to add the BC to 

1140 kind (str): kind of BC, e.g. Dirichlet 

1141 line (int): Line you want the BC to go in 

1142 scalar (bool): Put the BC in all space positions in the other direction 

1143 

1144 Returns: 

1145 sparse matrix containing the BC 

1146 """ 

1147 sp = scipy.sparse 

1148 

1149 base = self.axes[axis] 

1150 

1151 BC = sp.eye(base.N).tolil() * 0 

1152 if self.useGPU: 

1153 BC[line, :] = base.get_BC(kind=kind, **kwargs).get() 

1154 else: 

1155 BC[line, :] = base.get_BC(kind=kind, **kwargs) 

1156 

1157 ndim = len(self.axes) 

1158 if ndim == 1: 

1159 mat = self.sparse_lib.csc_matrix(BC) 

1160 elif ndim == 2: 

1161 axis2 = (axis + 1) % ndim 

1162 

1163 if scalar: 

1164 _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1))) 

1165 else: 

1166 _Id = self.axes[axis2].get_Id() 

1167 

1168 Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2) 

1169 

1170 mats = [ 

1171 None, 

1172 ] * ndim 

1173 mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis) 

1174 mats[axis2] = Id 

1175 mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats)) 

1176 elif ndim == 3: 

1177 mats = [ 

1178 None, 

1179 ] * ndim 

1180 

1181 for ax in range(ndim): 

1182 if ax == axis: 

1183 continue 

1184 

1185 if scalar: 

1186 _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1))) 

1187 else: 

1188 _Id = self.axes[ax].get_Id() 

1189 

1190 mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax) 

1191 

1192 mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis) 

1193 

1194 mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:]))) 

1195 else: 

1196 raise NotImplementedError( 

1197 f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!' 

1198 ) 

1199 mat = self.eliminate_zeros(mat) 

1200 return mat 

1201 

1202 def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs): 

1203 """ 

1204 Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively 

1205 remove single BCs again, as in incompressible Navier-Stokes, for instance. 

1206 Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details. 

1207 

1208 Args: 

1209 component (str): Name of the component the BC should act on 

1210 equation (str): Name of the equation for the component you want to put the BC in 

1211 axis (int): Axis you want to add the BC to 

1212 kind (str): kind of BC, e.g. Dirichlet 
1214 line (int): Line you want the BC to go in 

1215 scalar (bool): Put the BC in all space positions in the other direction 

1216 """ 

1217 _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs) 

1218 _BC = self.eliminate_zeros(_BC) 

1219 self.BC_mat[self.index(equation)][self.index(component)] -= _BC 

1220 

1221 if scalar: 

1222 slices = [self.index(equation)] + [ 

1223 0, 

1224 ] * self.ndim 

1225 slices[axis + 1] = line 

1226 else: 

1227 slices = ( 

1228 [self.index(equation)] 

1229 + [slice(0, self.init[0][i + 1]) for i in range(axis)] 

1230 + [line] 

1231 + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))] 

1232 ) 

1233 N = self.axes[axis].N 

1234 if (N + line) % N in self.xp.arange(N)[self.local_slice()[axis]]: 

1235 self.BC_rhs_mask[(*slices,)] = False 

1236 

1237 def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs): 

1238 """ 

1239 Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a 

1240 single sparse matrix by calling `setup_BCs` after adding/removing all BCs. 

1241 Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details. 

1242 

1243 Args: 

1244 component (str): Name of the component the BC should act on 

1245 equation (str): Name of the equation for the component you want to put the BC in 

1246 axis (int): Axis you want to add the BC to 

1247 kind (str): kind of BC, e.g. Dirichlet 

1248 v: Value of the BC 

1249 line (int): Line you want the BC to go in 

1250 scalar (bool): Put the BC in all space positions in the other direction 
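Example (a minimal illustrative sketch that imposes homogeneous Dirichlet BCs on both ends of a 1D Chebychev axis; resolution and labels are arbitrary):

.. code-block:: python

    from pySDC.helpers.spectral_helper import SpectralHelper

    helper = SpectralHelper()
    helper.add_axis(base='chebychev', N=8)
    helper.add_component('u')
    helper.setup_fft()

    # two BCs in the same equation must go into different lines
    helper.add_BC(component='u', equation='u', axis=0, kind='dirichlet', x=-1, v=0, line=-1)
    helper.add_BC(component='u', equation='u', axis=0, kind='dirichlet', x=1, v=0, line=-2)
    helper.setup_BCs()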

1251 """ 

1252 _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs) 

1253 self.BC_mat[self.index(equation)][self.index(component)] += _BC 

1254 self.full_BCs += [ 

1255 { 

1256 'component': component, 

1257 'equation': equation, 

1258 'axis': axis, 

1259 'kind': kind, 

1260 'v': v, 

1261 'line': line, 

1262 'scalar': scalar, 

1263 **kwargs, 

1264 } 

1265 ] 

1266 

1267 if scalar: 

1268 slices = [self.index(equation)] + [ 

1269 0, 

1270 ] * self.ndim 

1271 slices[axis + 1] = line 

1272 if self.comm: 

1273 if self.comm.rank == 0: 

1274 self.BC_rhs_mask[(*slices,)] = True 

1275 else: 

1276 self.BC_rhs_mask[(*slices,)] = True 

1277 else: 

1278 slices = [self.index(equation), *self.global_slice(True)] 

1279 N = self.axes[axis].N 

1280 if (N + line) % N in self.get_indices(True)[axis]: 

1281 slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start 

1282 self.BC_rhs_mask[(*slices,)] = True 

1283 

1284 def setup_BCs(self): 

1285 """ 

1286 Convert the list of lists of BCs to the boundary condition operator. 

1287 Also, boundary bordering requires zeroing out all other entries in the matrix in rows containing a boundary 

1288 condition. This method sets up a suitable sparse matrix to do this. 

1289 """ 

1290 sp = self.sparse_lib 

1291 self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat) 

1292 self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[self.BC_rhs_mask.flatten()] 

1293 

1294 diags = self.xp.ones(self.BCs.shape[0]) 

1295 diags[self.BC_zero_index] = 0 

1296 self.BC_line_zero_matrix = sp.diags(diags) 

1297 

1298 # prepare BCs in spectral space to easily add to the RHS 

1299 rhs_BCs = self.put_BCs_in_rhs(self.u_init) 

1300 self.rhs_BCs_hat = self.transform(rhs_BCs) 

1301 

1302 def check_BCs(self, u): 

1303 """ 

1304 Check that the solution satisfies the boundary conditions 

1305 

1306 Args: 

1307 u: The solution you want to check 

1308 """ 

1309 assert self.ndim < 3 

1310 for axis in range(self.ndim): 

1311 BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]] 

1312 

1313 if len(BCs) > 0: 

1314 u_hat = self.transform(u, axes=(axis - self.ndim,)) 

1315 for BC in BCs: 

1316 kwargs = { 

1317 key: value 

1318 for key, value in BC.items() 

1319 if key not in ['component', 'equation', 'axis', 'v', 'line', 'scalar'] 

1320 } 

1321 

1322 if axis == 0: 

1323 get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])] 

1324 elif axis == 1: 

1325 get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs) 

1326 want = BC['v'] 

1327 assert self.xp.allclose( 

1328 get, want 

1329 ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}' 

1330 

1331 def put_BCs_in_matrix(self, A): 

1332 """ 

1333 Put the boundary conditions in a matrix by replacing rows with BCs. 

1334 """ 

1335 return self.BC_line_zero_matrix @ A + self.BCs 

1336 

1337 def put_BCs_in_rhs_hat(self, rhs_hat): 

1338 """ 

1339 Put the BCs in the right hand side in spectral space for solving. 

1340 This function needs no transforms and caches a mask for faster subsequent use. 

1341 

1342 Args: 

1343 rhs_hat: Right hand side in spectral space 

1344 

1345 Returns: 

1346 rhs in spectral space with BCs 

1347 """ 

1348 if not hasattr(self, '_rhs_hat_zero_mask'): 

1349 """ 

1350 Generate a mask where we need to set values in the rhs in spectral space to zero, such that we can replace them 

1351 by the boundary conditions. The mask is then cached. 

1352 """ 

1353 self._rhs_hat_zero_mask = self.newDistArray().astype(bool) 

1354 

1355 for axis in range(self.ndim): 

1356 for bc in self.full_BCs: 

1357 if axis == bc['axis']: 

1358 slices = [self.index(bc['equation']), *self.global_slice(True)] 

1359 N = self.axes[axis].N 

1360 line = bc['line'] 

1361 if (N + line) % N in self.get_indices(True)[axis]: 

1362 slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start 

1363 self._rhs_hat_zero_mask[(*slices,)] = True 

1364 

1365 rhs_hat[self._rhs_hat_zero_mask] = 0 

1366 return rhs_hat + self.rhs_BCs_hat 

1367 

1368 def put_BCs_in_rhs(self, rhs): 

1369 """ 

1370 Put the BCs in the right hand side for solving. 

1371 This function will transform along each axis individually and add all BCs in that axis. 

1372 Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed. 

1373 

1374 Args: 

1375 rhs: Right hand side in physical space 

1376 

1377 Returns: 

1378 rhs in physical space with BCs 

1379 """ 

1380 assert rhs.ndim > 1, 'rhs must not be flattened here!' 

1381 

1382 ndim = self.ndim 

1383 

1384 for axis in range(ndim): 

1385 _rhs_hat = self.transform(rhs, axes=(axis - ndim,)) 

1386 

1387 for bc in self.full_BCs: 

1388 

1389 if axis == bc['axis']: 

1390 _slice = [self.index(bc['equation']), *self.global_slice(True)] 

1391 

1392 N = self.axes[axis].N 

1393 line = bc['line'] 

1394 if (N + line) % N in self.get_indices(True)[axis]: 

1395 _slice[axis + 1] = (N + line) % N - self.local_slice()[axis].start 

1396 _rhs_hat[(*_slice,)] = bc['v'] 

1397 

1398 rhs = self.itransform(_rhs_hat, axes=(axis - ndim,)) 

1399 

1400 return rhs 

1401 

1402 def add_equation_lhs(self, A, equation, relations): 

1403 """ 

1404 Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices 

1405 that you will convert to an operator later. 

1406 

1407 Example: 

1408 Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U 

1409 preconditioning: 

1410 

1411 .. code-block:: python 

1412 helper = SpectralHelper() 

1413 

1414 helper.add_axis(base='chebychev', N=8) 

1415 helper.add_component(['u', 'ux']) 

1416 helper.setup_fft() 

1417 

1418 I = helper.get_Id() 

1419 Dx = helper.get_differentiation_matrix(axes=(0,)) 

1420 T2U = helper.get_basis_change_matrix('T2U') 

1421 

1422 L_lhs = { 

1423 'ux': {'u': -T2U @ Dx, 'ux': T2U @ I}, 

1424 'u': {'ux': -(T2U @ Dx)}, 

1425 } 

1426 

1427 operator = helper.get_empty_operator_matrix() 

1428 for line, equation in L_lhs.items(): 

1429 helper.add_equation_lhs(operator, line, equation) 

1430 

1431 L = helper.convert_operator_matrix_to_operator(operator) 

1432 

1433 Args: 

1434 A (list of lists of sparse matrices): The operator to be assembled 
1435 equation (str): The equation of the component you want this in 
1436 relations (dict): Relations between quantities 

1437 """ 

1438 for k, v in relations.items(): 

1439 A[self.index(equation)][self.index(k)] = v 

1440 

1441 def eliminate_zeros(self, A): 

1442 """ 

1443 Eliminate zeros from sparse matrix. This can reduce memory footprint of matrices somewhat. 

1444 Note: At the time of writing, there are memory problems in the cupy implementation of `eliminate_zeros`. 

1445 Therefore, this function copies the matrix to host, eliminates the zeros there and then copies back to GPU. 

1446 

1447 Args: 

1448 A: sparse matrix to be pruned 

1449 

1450 Returns: 

1451 CSC sparse matrix 

1452 """ 

1453 if self.useGPU: 

1454 A = A.get() 

1455 A = A.tocsc() 

1456 A.eliminate_zeros() 

1457 if self.useGPU: 

1458 A = self.sparse_lib.csc_matrix(A) 

1459 return A 

1460 

1461 def convert_operator_matrix_to_operator(self, M): 

1462 """ 

1463 Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator. 

1464 See documentation of `SpectralHelper.add_equation_lhs` for an example. 

1465 

1466 Args: 

1467 M (list of lists of sparse matrices): The operator to be converted 

1468 

1469 Returns: 

1470 sparse linear operator 

1471 """ 

1472 if len(self.components) == 1: 

1473 op = M[0][0] 

1474 else: 

1475 op = self.sparse_lib.bmat(M, format='csc') 

1476 

1477 op = self.eliminate_zeros(op) 

1478 return op 

1479 

1480 def get_wavenumbers(self): 

1481 """ 

1482 Get grid in spectral space 

1483 """ 

1484 grids = [self.axes[i].get_wavenumbers()[self.local_slice(True)[i]] for i in range(len(self.axes))] 

1485 return self.xp.meshgrid(*grids, indexing='ij') 

1486 

1487 def get_grid(self, forward_output=False): 

1488 """ 

1489 Get grid in physical space 

1490 """ 

1491 grids = [self.axes[i].get_1dgrid()[self.local_slice(forward_output)[i]] for i in range(len(self.axes))] 

1492 return self.xp.meshgrid(*grids, indexing='ij') 

1493 

1494 def get_indices(self, forward_output=True): 

1495 return [self.xp.arange(self.axes[i].N)[self.local_slice(forward_output)[i]] for i in range(len(self.axes))] 

1496 

1497 @cache 

1498 def get_pfft(self, axes=None, padding=None, grid=None): 

1499 if self.ndim == 1 or self.comm is None: 

1500 return None 

1501 from mpi4py_fft import PFFT 

1502 

1503 axes = tuple(i for i in range(self.ndim)) if axes is None else axes 

1504 padding = list(padding if padding else [1.0 for _ in range(self.ndim)]) 

1505 

1506 def no_transform(u, *args, **kwargs): 

1507 return u 

1508 

1509 transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)} 

1510 for i in axes: 

1511 transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform) 

1512 

1513 # "transform" all axes to ensure consistent shapes. 

1514 # Transform non-distributable axes last to ensure they are aligned 

1515 _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes)) 

1516 _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted( 

1517 [axis for axis in _axes if self.axes[axis].distributable] 

1518 + [axis for axis in range(self.ndim) if axis not in _axes] 

1519 ) 

1520 

1521 pfft = PFFT( 

1522 comm=self.comm, 

1523 shape=self.global_shape[1:], 

1524 axes=_axes, # TODO: control the order of the transforms better 

1525 dtype='D', 

1526 collapse=False, 

1527 backend=self.fft_backend, 

1528 comm_backend=self.fft_comm_backend, 

1529 padding=padding, 

1530 transforms=transforms, 

1531 grid=grid, 

1532 ) 

1533 return pfft 

1534 

1535 def get_fft(self, axes=None, direction='object', padding=None, shape=None): 

1536 """ 

1537 When using MPI, we use `PFFT` objects generated by mpi4py-fft 

1538 

1539 Args: 

1540 axes (tuple): Axes you want to transform over 

1541 direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object 

1542 padding (tuple): Padding for dealiasing 

1543 shape (tuple): Shape of the transform 

1544 

1545 Returns: 

1546 transform 

1547 """ 

1548 axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes 

1549 shape = self.global_shape[1:] if shape is None else shape 

1550 padding = ( 

1551 [ 

1552 1, 

1553 ] 

1554 * self.ndim 

1555 if padding is None 

1556 else padding 

1557 ) 

1558 key = (axes, direction, tuple(padding), tuple(shape)) 

1559 

1560 if key not in self.fft_cache.keys(): 

1561 if self.comm is None: 

1562 assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms' 

1563 

1564 if direction == 'forward': 

1565 self.fft_cache[key] = self.xp.fft.fftn 

1566 elif direction == 'backward': 

1567 self.fft_cache[key] = self.xp.fft.ifftn 

1568 elif direction == 'object': 

1569 self.fft_cache[key] = None 

1570 else: 

1571 if direction == 'object': 

1572 from mpi4py_fft import PFFT 

1573 

1574 _fft = PFFT( 

1575 comm=self.comm, 

1576 shape=shape, 

1577 axes=sorted(axes), 

1578 dtype='D', 

1579 collapse=False, 

1580 backend=self.fft_backend, 

1581 comm_backend=self.fft_comm_backend, 

1582 padding=padding, 

1583 ) 

1584 else: 

1585 _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape) 

1586 

1587 if direction == 'forward': 

1588 self.fft_cache[key] = _fft.forward 

1589 elif direction == 'backward': 

1590 self.fft_cache[key] = _fft.backward 

1591 elif direction == 'object': 

1592 self.fft_cache[key] = _fft 

1593 

1594 return self.fft_cache[key] 

1595 

1596 def local_slice(self, forward_output=True): 

1597 if self.fft_obj: 

1598 return self.get_pfft().local_slice(forward_output=forward_output) 

1599 else: 

1600 return [slice(0, me.N) for me in self.axes] 

1601 

1602 def global_slice(self, forward_output=True): 

1603 if self.fft_obj: 

1604 return [slice(0, me) for me in self.fft_obj.global_shape(forward_output=forward_output)] 

1605 else: 

1606 return self.local_slice(forward_output=forward_output) 

1607 

1608 def setup_fft(self, real_spectral_coefficients=False): 

1609 """ 

1610 This function must be called after all axes have been set up in order to prepare the local shapes of the data. 

1611 This must also be called before setting up any BCs. 

1612 

1613 Args: 

1614 real_spectral_coefficients (bool): Allow only real coefficients in spectral space 

1615 """ 

1616 if len(self.components) == 0: 

1617 self.add_component('u') 

1618 

1619 self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes) 

1620 

1621 axes = tuple(i for i in range(len(self.axes))) 

1622 self.fft_obj = self.get_pfft(axes=axes) 

1623 

1624 self.init = ( 

1625 np.empty(shape=self.global_shape)[ 

1626 ( 

1627 ..., 

1628 *self.local_slice(False), 

1629 ) 

1630 ].shape, 

1631 self.comm, 

1632 np.dtype('float'), 

1633 ) 

1634 self.init_physical = ( 

1635 np.empty(shape=self.global_shape)[ 

1636 ( 

1637 ..., 

1638 *self.local_slice(False), 

1639 ) 

1640 ].shape, 

1641 self.comm, 

1642 np.dtype('float'), 

1643 ) 

1644 self.init_forward = ( 

1645 np.empty(shape=self.global_shape)[ 

1646 ( 

1647 ..., 

1648 *self.local_slice(True), 

1649 ) 

1650 ].shape, 

1651 self.comm, 

1652 np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'), 

1653 ) 

1654 

1655 self.BC_mat = self.get_empty_operator_matrix() 

1656 self.BC_rhs_mask = self.newDistArray().astype(bool) 

1657 

1658 def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False): 

1659 """ 

1660 Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but 

1661 takes care of all the solution components in the tensor. 

1662 """ 

1663 if self.comm is None: 

1664 return self.xp.zeros(self.init[0], dtype=self.init[2]) 

1665 from mpi4py_fft.distarray import DistArray 

1666 

1667 pfft = pfft if pfft else self.get_pfft() 

1668 if pfft is None: 

1669 if forward_output: 

1670 return self.u_init_forward 

1671 else: 

1672 return self.u_init 

1673 

1674 global_shape = pfft.global_shape(forward_output) 

1675 p0 = pfft.pencil[forward_output] 

1676 if forward_output is True: 

1677 dtype = pfft.forward.output_array.dtype 

1678 else: 

1679 dtype = pfft.forward.input_array.dtype 

1680 global_shape = (self.ncomponents,) * rank + global_shape 

1681 

1682 if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]: 

1683 from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls 

1684 else: 

1685 darraycls = DistArray 

1686 

1687 z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank) 

1688 return z.v if view else z 

1689 

1690 def infer_alignment(self, u, forward_output, padding=None, **kwargs): 

1691 if self.comm is None: 

1692 return [0] 

1693 

1694 def _alignment(pfft): 

1695 _arr = self.newDistArray(pfft, forward_output=forward_output) 

1696 _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]] 

1697 return _aligned_axes 

1698 

1699 if padding is None: 

1700 pfft = self.get_pfft(**kwargs) 

1701 aligned_axes = _alignment(pfft) 

1702 else: 

1703 if self.ndim == 2: 

1704 padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)] 

1705 elif self.ndim == 3: 

1706 padding_options = [ 

1707 (1.0, 1.0, padding[2]), 

1708 (1.0, padding[1], 1.0), 

1709 (padding[0], 1.0, 1.0), 

1710 (1.0, padding[1], padding[2]), 

1711 (padding[0], 1.0, padding[2]), 

1712 (padding[0], padding[1], 1.0), 

1713 padding, 

1714 (1.0, 1.0, 1.0), 

1715 ] 

1716 else: 

1717 raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!') 

1718 for _padding in padding_options: 

1719 pfft = self.get_pfft(padding=_padding, **kwargs) 

1720 aligned_axes = _alignment(pfft) 

1721 if len(aligned_axes) > 0: 

1722 self.logger.debug( 

1723 f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}' 

1724 ) 

1725 break 

1726 

1727 assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!' 

1728 return aligned_axes 

1729 

1730 def redistribute(self, u, axis, forward_output, **kwargs): 

1731 if self.comm is None: 

1732 return u 

1733 

1734 pfft = self.get_pfft(**kwargs) 

1735 _arr = self.newDistArray(pfft, forward_output=forward_output) 

1736 

1737 if 'Dist' in type(u).__name__ and False:  # NOTE: direct DistArray.redistribute shortcut is currently disabled 

1738 try: 

1739 u.redistribute(out=_arr) 

1740 return _arr 

1741 except AssertionError: 

1742 pass 

1743 

1744 u_alignment = self.infer_alignment(u, forward_output=False, **kwargs) 

1745 for alignment in u_alignment: 

1746 _arr = _arr.redistribute(alignment) 

1747 if _arr.shape == u.shape: 

1748 _arr[...] = u 

1749 return _arr.redistribute(axis) 

1750 

1751 raise Exception( 

1752 f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}' 

1753 ) 

1754 

1755 def transform(self, u, *args, axes=None, padding=None, **kwargs): 

1756 pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs) 

1757 

1758 if pfft is None: 

1759 axes = axes if axes else tuple(i for i in range(self.ndim)) 

1760 u_hat = u.copy() 

1761 for i in axes: 

1762 _axis = 1 + i if i >= 0 else i 

1763 u_hat = self.axes[i].transform(u_hat, axes=(_axis,)) 

1764 return u_hat 

1765 

1766 _in = self.newDistArray(pfft, forward_output=False, rank=1) 

1767 _out = self.newDistArray(pfft, forward_output=True, rank=1) 

1768 

1769 if _in.shape == u.shape: 

1770 _in[...] = u 

1771 else: 

1772 _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs) 

1773 

1774 for i in range(self.ncomponents): 

1775 pfft.forward(_in[i], _out[i], normalize=False) 

1776 

1777 if padding is not None: 

1778 _out /= np.prod(padding) 

1779 return _out 

1780 
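In the serial fallback (no `PFFT` object), the transform is applied axis by axis, with the index shifted by one to skip the leading component dimension. For a purely Fourier setup this matches a multidimensional FFT over the spatial axes, as this numpy-only sketch illustrates:

.. code-block:: python

    import numpy as np

    u = np.random.rand(2, 8, 8)                # (components, x, y)

    u_hat = u.astype(complex)
    for i in (0, 1):                           # spatial axes
        u_hat = np.fft.fft(u_hat, axis=1 + i)  # offset by one for the component axis

    assert np.allclose(u_hat, np.fft.fftn(u, axes=(1, 2)))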

1781 def itransform(self, u, *args, axes=None, padding=None, **kwargs): 

1782 if padding is not None: 

1783 assert all( 

1784 (self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim) 

1785 ), 'Cannot do this padding with this resolution. Resulting resolution must be integer' 

1786 

1787 pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs) 

1788 if pfft is None: 

1789 axes = axes if axes else tuple(i for i in range(self.ndim)) 

1790 u_hat = u.copy() 

1791 for i in axes: 

1792 _axis = 1 + i if i >= 0 else i 

1793 u_hat = self.axes[i].itransform(u_hat, axes=(_axis,)) 

1794 return u_hat 

1795 

1796 _in = self.newDistArray(pfft, forward_output=True, rank=1) 

1797 _out = self.newDistArray(pfft, forward_output=False, rank=1) 

1798 

1799 if _in.shape == u.shape: 

1800 _in[...] = u 

1801 else: 

1802 _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs) 

1803 

1804 for i in range(self.ncomponents): 

1805 pfft.backward(_in[i], _out[i], normalize=True) 

1806 

1807 if padding is not None: 

1808 _out *= np.prod(padding) 

1809 return _out 

1810 
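The multiplication by `np.prod(padding)` compensates for the backward normalization on the larger, padded grid, so a zero-padded spectrum evaluates to the correctly scaled interpolant. The same bookkeeping in a 1D numpy sketch:

.. code-block:: python

    import numpy as np

    N, pad = 16, 3 / 2                          # 3/2 dealiasing padding
    x = np.linspace(0, 2 * np.pi, N, endpoint=False)
    u = np.sin(x) + 0.3 * np.cos(3 * x)
    u_hat = np.fft.fft(u)

    M = int(N * pad)                            # padded resolution
    u_hat_pad = np.zeros(M, dtype=complex)
    u_hat_pad[: N // 2] = u_hat[: N // 2]       # positive frequencies
    u_hat_pad[-(N // 2):] = u_hat[-(N // 2):]   # negative frequencies

    x_fine = np.linspace(0, 2 * np.pi, M, endpoint=False)
    u_fine = np.fft.ifft(u_hat_pad).real * pad  # rescale by the padding factor
    assert np.allclose(u_fine, np.sin(x_fine) + 0.3 * np.cos(3 * x_fine))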

1811 def get_local_slice_of_1D_matrix(self, M, axis): 

1812 """ 

1813 Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes, 

1814 which you can sort out via the `SpectralHelper.local_slice()` method. When constructing a 1D matrix, you can 

1815 use this method to get the part corresponding to the modes carried by this rank. 

1816 

1817 Args: 

1818 M (sparse matrix): Global 1D matrix you want to get the local version of 

1819 axis (int): Direction in which you want the local version. You will get the global matrix in other directions. 

1820 

1821 Returns: 

1822 sparse local matrix 

1823 """ 

1824 return M.tocsc()[self.local_slice(True)[axis], self.local_slice(True)[axis]] 

1825 
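A standalone scipy sketch of the slicing performed here (the local slice is an assumed example of what one rank might own):

.. code-block:: python

    import scipy.sparse as sp

    M = sp.diags([1, 2, 3, 4, 5, 6, 7, 8])  # a global 1D operator on 8 modes
    local = slice(2, 6)                     # modes assumed to live on this rank

    M_local = M.tocsc()[local, local]       # block acting on the local modes only
    print(M_local.shape)                    # -> (4, 4)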

1826 def expand_matrix_ND(self, matrix, aligned): 

1827 sp = self.sparse_lib 

1828 axes = np.delete(np.arange(self.ndim), aligned) 

1829 ndim = len(axes) + 1 

1830 

1831 if ndim == 1: 

1832 mat = matrix 

1833 elif ndim == 2: 

1834 axis = axes[0] 

1835 I1D = sp.eye(self.axes[axis].N) 

1836 

1837 mats = [None] * ndim 

1838 mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned) 

1839 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) 

1840 

1841 mat = sp.kron(*mats) 

1842 elif ndim == 3: 

1843 

1844 mats = [None] * ndim 

1845 mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned) 

1846 for axis in axes: 

1847 I1D = sp.eye(self.axes[axis].N) 

1848 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) 

1849 

1850 mat = sp.kron(mats[0], sp.kron(*mats[1:])) 

1851 

1852 else: 

1853 raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!') 

1854 

1855 mat = self.eliminate_zeros(mat) 

1856 return mat 

1857 
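The N-dimensional expansion is a Kronecker product of the 1D operator with identities on the remaining axes. The scipy sketch below shows the 2D case with the operator aligned along the last axis, acting on a C-order flattened field:

.. code-block:: python

    import numpy as np
    import scipy.sparse as sp

    Nx, Ny = 4, 5
    D1D = sp.diags([-1.0, 1.0], offsets=[0, 1], shape=(Ny, Ny))  # 1D operator along y
    D2D = sp.kron(sp.eye(Nx), D1D)                               # identity along x

    u = np.random.rand(Nx, Ny)
    Du = (D2D @ u.flatten()).reshape(Nx, Ny)     # apply to the flattened field
    assert np.allclose(Du, u @ D1D.toarray().T)  # same as applying D1D to every row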

1858 def get_filter_matrix(self, axis, **kwargs): 

1859 """ 

1860 Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are 

1861 admissible. 

1862 

1863 Returns: 

1864 sparse bandpass matrix 

1865 """ 

1866 if self.ndim == 1: 

1867 return self.axes[0].get_filter_matrix(**kwargs) 

1868 

1869 mats = [base.get_Id() for base in self.axes] 

1870 mats[axis] = self.axes[axis].get_filter_matrix(**kwargs) 

1871 return self.sparse_lib.kron(*mats) 

1872 

1873 def get_differentiation_matrix(self, axes, **kwargs): 

1874 """ 

1875 Get differentiation matrix along the specified axes. `kwargs` are forwarded to the 1D base implementation. 

1876 

1877 Args: 

1878 axes (tuple): Axes along which to differentiate. 

1879 

1880 Returns: 

1881 sparse differentiation matrix 

1882 """ 

1883 D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0]) 

1884 for axis in axes[1:]: 

1885 _D = self.axes[axis].get_differentiation_matrix(**kwargs) 

1886 D = D @ self.expand_matrix_ND(_D, axis) 

1887 

1888 return D 

1889 
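Composing matrices expanded along different axes yields mixed operators, because Kronecker products multiply factor-wise: `kron(A, I) @ kron(I, B) == kron(A, B)`. A scipy sketch of this identity with stand-in matrices:

.. code-block:: python

    import numpy as np
    import scipy.sparse as sp

    rng = np.random.default_rng(0)
    Nx, Ny = 4, 5
    Dx = sp.csr_matrix(rng.random((Nx, Nx)))  # stand-in 1D operator along x
    Dy = sp.csr_matrix(rng.random((Ny, Ny)))  # stand-in 1D operator along y

    Dx_ND = sp.kron(Dx, sp.eye(Ny))           # expanded along axis 0
    Dy_ND = sp.kron(sp.eye(Nx), Dy)           # expanded along axis 1

    assert np.allclose((Dx_ND @ Dy_ND).toarray(), sp.kron(Dx, Dy).toarray())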

1890 def get_integration_matrix(self, axes): 

1891 """ 

1892 Get integration matrix along the specified axes. 

1893 

1894 Args: 

1895 axes (tuple): Axes along which to integrate. 

1896 

1897 Returns: 

1898 sparse integration matrix 

1899 """ 

1900 S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0]) 

1901 for axis in axes[1:]: 

1902 _S = self.axes[axis].get_integration_matrix() 

1903 S = S @ self.expand_matrix_ND(_S, axis) 

1904 

1905 return S 

1906 

1907 def get_Id(self): 

1908 """ 

1909 Get identity matrix 

1910 

1911 Returns: 

1912 sparse identity matrix 

1913 """ 

1914 I = self.expand_matrix_ND(self.axes[0].get_Id(), 0) 

1915 for axis in range(1, self.ndim): 

1916 _I = self.axes[axis].get_Id() 

1917 I = I @ self.expand_matrix_ND(_I, axis) 

1918 return I 

1919 

1920 def get_Dirichlet_recombination_matrix(self, axis=-1): 

1921 """ 

1922 Get Dirichlet recombination matrix along axis. Note that it only makes sense in directions discretized with variations of Chebychev bases. 

1923 

1924 Args: 

1925 axis (int): Axis you discretized with Chebychev 

1926 

1927 Returns: 

1928 sparse matrix 

1929 """ 

1930 C1D = self.axes[axis].get_Dirichlet_recombination_matrix() 

1931 return self.expand_matrix_ND(C1D, axis) 

1932 

1933 def get_basis_change_matrix(self, axes=None, **kwargs): 

1934 """ 

1935 Some spectral bases change between bases while differentiating. This method returns a matrix that changes the basis to whatever you want. 

1936 Refer to the methods of the same name in the 1D bases to learn which parameters to pass here as `kwargs`. 

1937 

1938 Args: 

1939 axes (tuple): Axes along which to change basis. 

1940 

1941 Returns: 

1942 sparse basis change matrix 

1943 """ 

1944 axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes 

1945 

1946 C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0]) 

1947 for axis in axes[1:]: 

1948 _C = self.axes[axis].get_basis_change_matrix(**kwargs) 

1949 C = C @ self.expand_matrix_ND(_C, axis) 

1950 

1951 return C
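To give an impression of how these getters fit together, here is a minimal, hypothetical usage sketch. The constructor defaults and the `add_axis` call are assumptions based on how the class is configured elsewhere in this module (outside this excerpt), not something defined in the code above:

.. code-block:: python

    # Hypothetical sketch; add_axis and the constructor defaults are assumed, not shown above.
    from pySDC.helpers.spectral_helper import SpectralHelper

    helper = SpectralHelper()                  # serial run: comm stays None (assumption)
    helper.add_axis(base='fft', N=32)          # axis 0, Fourier basis (assumed signature)
    helper.add_axis(base='fft', N=32)          # axis 1, Fourier basis (assumed signature)
    helper.setup_fft()

    Dx = helper.get_differentiation_matrix(axes=(0,))
    Dy = helper.get_differentiation_matrix(axes=(1,))
    I = helper.get_Id()

    laplacian = Dx @ Dx + Dy @ Dy              # compose per-axis first derivatives
    system = I - 1e-2 * laplacian              # e.g. an implicit diffusion operator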