Coverage for pySDC/helpers/spectral_helper.py: 92%

773 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5 

6 

class SpectralHelper1D:
    """
    Common interface that every one-dimensional spectral basis has to provide.

    Concrete bases derive from this class and fill in the methods that raise `NotImplementedError` here. The
    numerical back ends are stored as class attributes so that they can be swapped for their GPU counterparts;
    implementations should always go through these attributes instead of importing NumPy / SciPy directly.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
    """

    # numerical back ends; replaced by the CuPy equivalents in `setup_GPU`
    xp = np
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    fft_lib = scipy.fft

    def __init__(self, N, x0=None, x1=None, useGPU=False):
        """
        Store the resolution and the domain and optionally switch to the GPU modules.

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
        """
        self.N, self.x0, self.x1 = N, x0, x1
        self.L = x1 - x0
        self.useGPU = useGPU

        if useGPU:
            self.setup_GPU()

    @classmethod
    def setup_GPU(cls):
        """
        Replace the numerical back ends stored on the class by their CuPy counterparts.
        """
        import cupy
        import cupyx.scipy.sparse as gpu_sparse
        import cupyx.scipy.sparse.linalg as gpu_linalg
        import cupyx.scipy.fft as gpu_fft

        # imported as in the original implementation although the name is not used in this class
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh  # noqa: F401

        cls.fft_lib = gpu_fft
        cls.linalg = gpu_linalg
        cls.sparse_lib = gpu_sparse
        cls.xp = cupy

    def get_Id(self):
        """
        Build the identity matrix of size `N` x `N`.

        Returns:
            sparse diagonal identity matrix
        """
        size = self.N
        return self.sparse_lib.eye(size)

    def get_zero(self):
        """
        Build a matrix of the same size as the identity that contains only zeros.

        Returns:
            sparse matrix with zeros everywhere
        """
        identity = self.get_Id()
        return 0 * identity

    def get_differentiation_matrix(self):
        """Differentiation matrix; has to be provided by the concrete basis."""
        raise NotImplementedError()

    def get_integration_matrix(self):
        """Integration matrix; has to be provided by the concrete basis."""
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Grid in spectral space; has to be provided by the concrete basis.
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Set up an `S` x `S` table of operators in which every entry refers to the object `O`.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        table = []
        for _ in range(S):
            table.append([O] * S)
        return table

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Matrix for changing between the bases that a discretization may pass through during differentiation.

        Arbitrary arguments are accepted and ignored here so that multi-dimensional helpers can ask all of their 1D
        bases for the same conversion: a basis that never changes, which is the behaviour implemented in this base
        class, simply answers with the identity, while bases that do change can override this method and return the
        actual conversion.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        Row that implements a boundary condition via boundary bordering, i.e. by replacing a line of the system
        matrix. The base class does not implement any boundary condition; see the individual bases for the kinds
        they support.

        Args:
            kind (str): The type of boundary condition

        Returns:
            self.xp.array: Boundary condition

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Build a band-pass filter: columns belonging to modes with `|k| < kmin` or `|k| >= kmax` are set to zero in
        an identity matrix. If `kmax` is not given, the largest absolute wavenumber is used.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter

        Returns:
            sparse matrix
        """
        wavenumbers = abs(self.get_wavenumbers())

        if kmax is None:
            kmax = max(wavenumbers)

        rejected = self.xp.logical_or(wavenumbers >= kmax, wavenumbers < kmin)

        identity = self.get_Id()
        if self.useGPU:
            # the filter is assembled on the host
            identity = identity.get()

        band_pass = identity.tolil()
        band_pass[:, rejected] = 0
        return band_pass.tocsc()

    def get_1dgrid(self):
        """
        Grid in physical space; has to be provided by the concrete basis.

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

172 

173 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
    """

    def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
                                             use the FFT from the library to compute the DCT. Case insensitive.
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented
        """
        assert x0 == -1
        assert x1 == 1
        super().__init__(*args, x0=x0, x1=x1, **kwargs)
        self.transform_type = transform_type

        # compare case-insensitively, exactly like `transform` does, such that the FFT utilities exist whenever the
        # FFT branch of the transforms can be reached
        if self.transform_type.lower() == 'fft':
            self.get_fft_utils()

        self.cache = {}
        self.norm = self.get_norm()

    @staticmethod
    def _expand_along_axis(values, ndim, axis):
        """
        Reshape the 1D array `values` such that it broadcasts along `axis` of an array with `ndim` dimensions.
        """
        index = [None] * ndim
        index[axis] = slice(None)
        return values[tuple(index)]

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5))

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    def _get_forward_conv(self, name, N):
        """
        Build one of the directly available conversion matrices of size `N`; raises `NotImplementedError` otherwise.
        """
        sp = self.sparse_lib
        xp = self.xp

        if name == 'T2U':
            # (Id - shift by two) / 2 with the first column doubled. The first column only contains the diagonal
            # entry, so the doubling is applied to the main diagonal before assembling the matrix.
            main = xp.ones(N) / 2.0
            main[:1] = 1.0
            return sp.diags(main) - sp.diags(xp.ones(N - 2) / 2.0, offsets=+2)
        if name == 'D2T':
            return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
        if name[0] == name[-1]:
            # conversion of a basis to itself; honours the requested size
            return sp.eye(N)
        raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')

    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        If only the reverse conversion is directly available, it is inverted on the CPU.
        Matrices of the default size `self.N` are cached once generated, so feel free to call the method as often as
        you like. Matrices of a different size are never stored in the cache.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its reverse is available
        '''
        size = N if N else self.N
        cacheable = size == self.N

        if cacheable and name in self.cache:
            return self.cache[name]

        try:
            mat = self._get_forward_conv(name, size)
        except NotImplementedError as E:
            try:
                fwd = self._get_forward_conv(name[::-1], size).tocsc()
            except NotImplementedError:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}') from E

            # the inversion is always carried out with SciPy on the host
            if self.sparse_lib is scipy.sparse:
                mat = scipy.sparse.linalg.inv(fwd)
            else:
                mat = self.sparse_lib.csc_matrix(scipy.sparse.linalg.inv(fwd.get()))

        if cacheable:
            self.cache[name] = mat
        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: If `lbnd` is not 0
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        N = self.N
        n = xp.arange(N)
        half = N // 2

        S = (self.sparse_lib.diags(1 / (xp.arange(N - 1) + 1), offsets=-1) @ self.get_conv('T2U')).tocsc()

        # first row: contributions of the odd modes to the constant term. The last factor divides by
        # 1, 1, 2, 3, ... which is written as maximum(0, 1, 2, 3, ...; 1) to stay within `xp`.
        S[0, 1::2] = (n / (2 * (n + 1)))[1::2] * (-1) ** xp.arange(half) / xp.maximum(xp.arange(half), 1)
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            Differentiation matrix, stored in sparse (CSC) format
        '''
        xp = self.xp
        col = xp.arange(self.N)[None, :]
        row = xp.arange(self.N)[:, None]

        # D[k, j] = 2 * j if k < j and (j - k) is odd, zero otherwise
        D = xp.where(row < col, 2.0 * col * ((col - row) % 2), 0.0)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p))

    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def get_fft_shuffle(self, forward, N):
        """
        In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
        doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
        in the complex plane. This function returns a mask to do the shuffling.

        Args:
            forward (bool): Whether you want the shuffle for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Use as mask
        """
        xp = self.xp
        if forward:
            return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
        else:
            mask = xp.zeros(N, dtype=int)
            mask[: N - N % 2 : 2] = xp.arange(N // 2)
            mask[1::2] = N - xp.arange(N // 2) - 1
            mask[-1] = N // 2
            return mask

    def get_fft_shift(self, forward, N):
        """
        As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.

        Args:
            forward (bool): Whether you want the rotation for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Rotation
        """
        xp = self.xp
        # wavenumbers and normalization are built for the requested size `N`, not for `self.N`
        k = xp.arange(N)
        norm = self.get_norm(N)
        if forward:
            return 2 * xp.exp(-1j * np.pi * k / (2 * N)) * norm
        else:
            shift = xp.exp(1j * np.pi * k / (2 * N))
            shift[0] = 0.5
            return shift / norm

    def get_fft_utils(self):
        """
        Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
        them cached.

        Returns:
            dict: Shuffle masks and rotations for the forward ('fwd') and backward ('bck') transform
        """
        self.fft_utils = {
            'fwd': {
                'shuffle': self.get_fft_shuffle(True, self.N),
                'shift': self.get_fft_shift(True, self.N),
            },
            'bck': {
                'shuffle': self.get_fft_shuffle(False, self.N),
                'shift': self.get_fft_shift(False, self.N),
            },
        }
        return self.fft_utils

    def transform(self, u, axis=-1, **kwargs):
        """
        1D DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in spectral space

        Raises:
            NotImplementedError: If the transform type is neither 'fft' nor 'dct'
        """
        kind = self.transform_type.lower()

        if kind == 'dct':
            # the normalization has to act along the transformed axis, not along the last one
            norm = self._expand_along_axis(self.norm, u.ndim, axis)
            return self.fft_lib.dct(u, axis=axis, **kwargs) * norm

        if kind == 'fft':
            result = u.copy()

            shuffle = [slice(None)] * u.ndim
            shuffle[axis] = self.fft_utils['fwd']['shuffle']

            V = self.fft_lib.fft(u[tuple(shuffle)], axis=axis, **kwargs)
            V *= self._expand_along_axis(self.fft_utils['fwd']['shift'], u.ndim, axis)

            # only the real part is replaced by the transform
            result.real[...] = V.real[...]
            return result

        raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def itransform(self, u, axis=-1):
        """
        1D inverse DCT along axis.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in physical space

        Raises:
            NotImplementedError: If the transform type is neither 'fft' nor 'dct'
        """
        assert self.norm.shape[0] == u.shape[axis]

        kind = self.transform_type.lower()

        if kind == 'dct':
            # the normalization has to act along the transformed axis, not along the last one
            norm = self._expand_along_axis(self.norm, u.ndim, axis)
            return self.fft_lib.idct(u / norm, axis=axis)

        if kind == 'fft':
            result = u.copy()

            shift = self._expand_along_axis(self.fft_utils['bck']['shift'], u.ndim, axis)
            v = self.fft_lib.ifft(u * shift, axis=axis)

            shuffle = [slice(None)] * u.ndim
            shuffle[axis] = self.fft_utils['bck']['shuffle']
            V = v[tuple(shuffle)]

            # only the real part is replaced by the transform
            result.real[...] = V.real[...]
            return result

        raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition, one of -1, 0 and 1

        Returns:
            self.xp.ndarray: Row to put into a matrix

        Raises:
            NotImplementedError: For any other position
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)

541 

542 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped
    from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        order = p
        return 2 ** (order - 1) * factorial(order - 1) * sp.diags(xp.arange(N - order) + order, offsets=order)

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        if lmbda == 0:
            # (Id - shift by two) / 2 with the first column doubled. The first column only contains the diagonal
            # entry, so the doubling is applied to the main diagonal before assembling the matrix.
            main = xp.ones(N) / 2.0
            main[:1] = 1.0
            mat = sp.diags(main) - sp.diags(xp.ones(N - 2) / 2.0, offsets=+2)
        else:
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return sp.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`. `**kwargs` are ignored; they absorb
        arguments meant for other bases.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        if p_in == p_out:
            # nothing to convert and, in particular, nothing to invert
            return self.sparse_lib.eye(self.N, format='csc')

        lower, upper = sorted((p_in, p_out))
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(lower, upper):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = scipy.sparse.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_basis_change_matrix(
            p_out=1, p_in=0
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of
        `UltrasphericalHelper.get_integration_matrix` for details.

        The constant is the alternating sum `u_hat[1] - u_hat[2] + u_hat[3] - ...` taken along `axis`.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp
        n = u_hat.shape[axis]

        # skip the zeroth coefficient along `axis` and keep all other axes untouched
        selection = [slice(None)] * u_hat.ndim
        selection[axis] = slice(1, n)

        # align the alternating signs with `axis` for broadcasting
        expansion = [None] * u_hat.ndim
        expansion[axis] = slice(None)
        signs = ((-1) ** xp.arange(n - 1))[tuple(expansion)]

        return xp.sum(u_hat[tuple(selection)] * signs, axis=axis)

670 

671 

class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems; the transforms are computed with the FFT.
    """

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary and not including the right one, which coincides
        with the left boundary due to periodicity.

        Returns:
            self.xp.array: Grid
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive: it is the one returned by `fftfreq`.

        Returns:
            self.xp.array: Wavenumbers, scaled by `2 * pi / L`
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def _get_diagonal_power(self, diagonal, p):
        """
        Return the `p`-th power of the diagonal matrix with entries `diagonal`. Since the matrix is diagonal, the
        power is taken element-wise, which works with any array library and does not need a sparse `matrix_power`.
        """
        return self.sparse_lib.diags(diagonal**p, format='csc')

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()
        return self._get_diagonal_power(1j * k, p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        The diagonal entries are `1 / (1j * k)`, where the zero mode is treated separately by replacing `k[0]` with
        `1j * L`.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L
        return self._get_diagonal_power(1 / (1j * k), p)

    def transform(self, u, axis=-1, **kwargs):
        """
        1D FFT along axis. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.fft(u, axis=axis, **kwargs)

    def itransform(self, u, axis=-1):
        """
        Inverse 1D FFT.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.ifft(u, axis=axis)

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        # first index at which the smallest wavenumber occurs
        return self.xp.argmin(self.get_wavenumbers())

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me

805 return me 

806 

807 

808class SpectralHelper: 

809 """ 

810 This class has three functions: 

811 - Easily assemble matrices containing multiple equations 

812 - Direct product of 1D bases to solve problems in more dimensions 

813 - Distribute the FFTs to facilitate concurrency. 

814 

815 Attributes: 

816 comm (mpi4py.Intracomm): MPI communicator 

817 debug (bool): Perform additional checks at extra computational cost 

818 useGPU (bool): Whether to use GPUs 

819 axes (list): List of 1D bases 

820 components (list): List of strings of the names of components in the equations 

821 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

822 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

823 BCs (sparse matrix): Matrix containing only the BCs 

824 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

825 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

826 BC_zero_index (self.xp.ndarray): Indices of rows in the matrix that are replaced by BCs 

827 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

828 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

829 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

830 local_slice (slice): Local slice of the solution as in `mpi4py-fft` 

831 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

832 init (tuple): This is the same `init` that is used throughout the problem classes 

833 init_forward (tuple): This is the equivalent of `init` in spectral space 

834 """ 

835 

836 xp = np 

837 fft_lib = scipy.fft 

838 sparse_lib = scipy.sparse 

839 linalg = scipy.sparse.linalg 

840 dtype = mesh 

841 fft_backend = 'fftw' 

842 fft_comm_backend = 'MPI' 

843 

@classmethod
def setup_GPU(cls):
    """
    Replace the CPU modules, the data type and the FFT backends stored on the class by their GPU counterparts.
    """
    import cupy
    import cupyx.scipy.sparse as gpu_sparse
    import cupyx.scipy.sparse.linalg as gpu_linalg
    from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh as gpu_mesh

    # data container used for the solution
    cls.dtype = gpu_mesh

    # backends for the distributed FFTs
    cls.fft_backend, cls.fft_comm_backend = 'cupy', 'NCCL'

    # array and sparse linear algebra libraries
    cls.xp = cupy
    cls.sparse_lib = gpu_sparse
    cls.linalg = gpu_linalg

860 

861 def __init__(self, comm=None, useGPU=False, debug=False): 

862 """ 

863 Constructor 

864 

865 Args: 

866 comm (mpi4py.Intracomm): MPI communicator 

867 useGPU (bool): Whether to use GPUs 

868 debug (bool): Perform additional checks at extra computational cost 

869 """ 

870 self.comm = comm 

871 self.debug = debug 

872 self.useGPU = useGPU 

873 

874 if useGPU: 

875 self.setup_GPU() 

876 

877 self.axes = [] 

878 self.components = [] 

879 

880 self.full_BCs = [] 

881 self.BC_mat = None 

882 self.BCs = None 

883 

884 self.fft_cache = {} 

885 

886 @property 

887 def u_init(self): 

888 """ 

889 Get empty data container in physical space 

890 """ 

891 return self.dtype(self.init) 

892 

893 @property 

894 def u_init_forward(self): 

895 """ 

896 Get empty data container in spectral space 

897 """ 

898 return self.dtype(self.init_forward) 

899 

900 @property 

901 def shape(self): 

902 """ 

903 Get shape of individual solution component 

904 """ 

905 return self.init[0][1:] 

906 

907 @property 

908 def ndim(self): 

909 return len(self.axes) 

910 

911 @property 

912 def ncomponents(self): 

913 return len(self.components) 

914 

915 @property 

916 def V(self): 

917 """ 

918 Get domain volume 

919 """ 

920 return np.prod([me.L for me in self.axes]) 

921 

def add_axis(self, base, *args, **kwargs):
    """
    Append a new axis to the domain, discretized with the requested 1D base.
    All `*args` and `**kwargs` are passed on to the constructor of the 1D base; see the documentation of the 1D
    bases for what they accept. The `useGPU` setting of this helper always overrides the one in `kwargs`.

    Args:
        base (str): 1D spectral method (case insensitive)

    Raises:
        NotImplementedError: If `base` does not name a known 1D base
    """
    kwargs['useGPU'] = self.useGPU

    requested = base.lower()
    if requested in ('chebychov', 'chebychev', 'cheby', 'chebychovhelper'):
        # Chebychev axes default to the FFT-based transform unless the caller chose one
        kwargs.setdefault('transform_type', 'fft')
        new_axis = ChebychevHelper(*args, **kwargs)
    elif requested in ('fft', 'fourier', 'ffthelper'):
        new_axis = FFTHelper(*args, **kwargs)
    elif requested in ('ultraspherical', 'gegenbauer'):
        new_axis = UltrasphericalHelper(*args, **kwargs)
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    self.axes.append(new_axis)

    # make the new axis use the same array and sparse modules as this helper
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib

944 

def add_component(self, name):
    """
    Register one or more solution components.
    Lists and tuples are processed entry by entry (recursively), strings are appended to `self.components`.

    Args:
        name (str or list of strings): Name(s) of component(s)

    Raises:
        Exception: If a component of that name has been added before
        NotImplementedError: If `name` is neither a string nor a list/tuple
    """
    kind = type(name)

    if kind in (list, tuple):
        for entry in name:
            self.add_component(entry)
        return

    if kind not in (str,):
        raise NotImplementedError

    if name in self.components:
        raise Exception(f'{name=!r} is already added to this problem!')
    self.components.append(name)

961 

def index(self, name):
    """
    Look up the position of component `name` in `self.components`.

    Args:
        name (str or list of strings): Name(s) of component(s)

    Returns:
        int: Index of the component if `name` is a single name. For a list or tuple of names, a generator
        yielding the index of each entry is returned instead.

    Raises:
        NotImplementedError: If the type of `name` is not supported
    """
    kind = type(name)

    if kind in (list, tuple):
        return (self.index(entry) for entry in name)
    if kind in (str, int):
        return self.components.index(name)

    raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

978 

def get_empty_operator_matrix(self):
    """
    Build a square list of lists with one entry per pair of solution components, each entry being a zero operator.
    The entries are meant to be replaced by the operators coupling the respective components.

    Returns:
        list containing sparse zeros
    """
    n_components = len(self.components)
    zero = self.get_Id() * 0

    rows = []
    for _ in range(n_components):
        # every entry refers to the same zero operator; the rows themselves are separate lists
        rows.append([zero] * n_components)
    return rows

989 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: If the helper has more than two axes
    """
    # the row is assembled with scipy on the host; the result is converted to `self.sparse_lib` on return
    sp = scipy.sparse

    base = self.axes[axis]

    # empty matrix in LIL format, which allows assigning a whole row
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # only the first position along the other axis carries the BC
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        if self.useGPU:
            Id = Id.get()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        return self.sparse_lib.csc_matrix(sp.kron(*mats))
    else:
        # previously this case fell through and silently returned None
        raise NotImplementedError(f'Boundary conditions are not implemented for {ndim} dimensions!')

1038 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Take a BC out of the BC matrix again and clear the corresponding entries in the right hand side mask.
    This is useful e.g. when you add a non-scalar BC and then need to selectively remove single BCs again, as in
    incompressible Navier-Stokes, for instance.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC acts on
        equation (str): Name of the equation for the component the BC was put in
        axis (int): Axis the BC was added to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line the BC is in
        scalar (bool): Whether the BC was put in all space positions in the other direction
    """
    bc_matrix = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= bc_matrix

    if scalar:
        mask_index = [self.index(equation)] + [0] * self.ndim
        mask_index[axis + 1] = line
    else:
        local_shape = self.init[0]
        before = [slice(0, local_shape[i + 1]) for i in range(axis)]
        after = [slice(0, local_shape[i + 1]) for i in range(axis + 1, len(self.axes))]
        mask_index = [self.index(equation)] + before + [line] + after

    # only touch the mask if the (non-negative) line index is among the lines selected by the local slice
    # NOTE(review): unlike `add_BC`, the line is not shifted by `self.local_slice[axis].start` here -- confirm
    # that this is intended when the local slice along `axis` does not start at zero.
    N = self.axes[axis].N
    local_lines = self.xp.arange(N)[self.local_slice[axis]]
    if (N + line) % N in local_lines:
        self.BC_rhs_mask[(*mask_index,)] = False

1072 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the list of lists of BC matrices, record it in `self.full_BCs` and mark the affected entries in
    the right hand side mask. After adding/removing all BCs, call `setup_BCs` to convert the list of lists to a
    single sparse matrix.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    bc_matrix = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += bc_matrix

    # remember everything needed to put this BC into right hand sides later
    record = {
        'component': component,
        'equation': equation,
        'axis': axis,
        'kind': kind,
        'v': v,
        'line': line,
        'scalar': scalar,
        **kwargs,
    }
    self.full_BCs.append(record)

    if scalar:
        mask_index = [self.index(equation)] + [0] * self.ndim
        mask_index[axis + 1] = line
        # with a communicator, only rank 0 marks the scalar BC
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*mask_index,)] = True
    else:
        local_shape = self.init[0]
        before = [slice(0, local_shape[i + 1]) for i in range(axis)]
        after = [slice(0, local_shape[i + 1]) for i in range(axis + 1, len(self.axes))]
        mask_index = [self.index(equation)] + before + [line] + after

        # only mark the line if it is among the lines selected by the local slice
        N = self.axes[axis].N
        local_lines = self.xp.arange(N)[self.local_slice[axis]]
        if (N + line) % N in local_lines:
            mask_index[axis + 1] -= self.local_slice[axis].start
            self.BC_rhs_mask[(*mask_index,)] = True

1124 

def setup_BCs(self):
    """
    Finalize the boundary conditions after all of them have been added/removed.
    This converts the list of lists of BCs to a single BC operator, builds the diagonal matrix that zeros out all
    rows of an operator which contain a BC (needed for boundary bordering), and precomputes the BC values in
    spectral space so they can be added to right hand sides.
    """
    sparse = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    # flat indices of all rows that hold a boundary condition
    n_entries = np.prod(self.init[0])
    flat_mask = self.BC_rhs_mask.flatten()
    self.BC_zero_index = self.xp.arange(n_entries)[flat_mask]

    # identity with zeros on the diagonal in the BC rows
    keep = self.xp.ones(self.BCs.shape[0])
    keep[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sparse.diags(keep)

    # prepare BCs in spectral space to easily add to the RHS
    bc_values = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(bc_values)

1142 

def check_BCs(self, u):
    """
    Verify that a solution fulfils all non-scalar boundary conditions stored in `self.full_BCs`.
    For every axis with such BCs, the solution is transformed along that axis only and the BC row of the 1D base
    is applied to the respective component. Works in one and two dimensions.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a boundary condition is not satisfied
    """
    assert self.ndim < 3

    # entries of the stored BC that are not forwarded to the 1D base
    not_forwarded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]
        if len(BCs) == 0:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for BC in BCs:
            base_kwargs = {}
            for key, value in BC.items():
                if key not in not_forwarded:
                    base_kwargs[key] = value

            if axis == 0:
                get = self.axes[axis].get_BC(**base_kwargs) @ u_hat[self.index(BC['component'])]
            elif axis == 1:
                get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**base_kwargs)
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'

1171 

def put_BCs_in_matrix(self, A):
    """
    Replace the rows of `A` that contain boundary conditions by the boundary conditions.

    Args:
        A: Operator to put the BCs in

    Returns:
        `A` with BC rows zeroed out via `self.BC_line_zero_matrix`, plus the BC operator `self.BCs`
    """
    without_bc_rows = self.BC_line_zero_matrix @ A
    return without_bc_rows + self.BCs

1177 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    No transforms are performed here. On the first call, a mask of the entries holding BCs is built and stored
    in `self._rhs_hat_zero_mask`; subsequent calls reuse it. Note that `rhs_hat` is modified in place.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate a mask where we need to set values in the rhs in spectral space to zero, such that we can
        # replace them by the boundary conditions. The mask is then cached.
        self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)
        local_shape = self.init[0]

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                before = [slice(0, local_shape[i + 1]) for i in range(axis)]
                after = [slice(0, local_shape[i + 1]) for i in range(axis + 1, len(self.axes))]
                line_index = before + [bc['line']] + after

                if axis != bc['axis']:
                    continue

                mask_index = [self.index(bc['equation'])] + line_index
                N = self.axes[axis].N
                local_lines = self.xp.arange(N)[self.local_slice[axis]]
                if (N + bc['line']) % N in local_lines:
                    mask_index[axis + 1] -= self.local_slice[axis].start
                    self._rhs_hat_zero_mask[(*mask_index,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1212 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    For each axis in turn, the right hand side is transformed along that axis only, the values of all BCs along
    that axis are written into it and it is transformed back.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    n_axes = self.ndim
    local_shape = self.init[0]

    for axis in range(n_axes):
        transform_axes = (axis - n_axes,)
        _rhs_hat = self.transform(rhs, axes=transform_axes)

        for bc in self.full_BCs:
            before = [slice(0, local_shape[i + 1]) for i in range(axis)]
            after = [slice(0, local_shape[i + 1]) for i in range(axis + 1, len(self.axes))]
            line_index = before + [bc['line']] + after

            if axis != bc['axis']:
                continue

            target = [self.index(bc['equation'])] + line_index

            # only write the value if the line is among the lines selected by the local slice
            N = self.axes[axis].N
            local_lines = self.xp.arange(N)[self.local_slice[axis]]
            if (N + bc['line']) % N in local_lines:
                target[axis + 1] -= self.local_slice[axis].start
                _rhs_hat[(*target,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=transform_axes)

    return rhs

1250 

def add_equation_lhs(self, A, equation, relations):
    """
    Enter the left hand side (the part you want to solve implicitly) of one equation into a list of lists of
    sparse matrices, which you will convert to an operator later. Existing entries of `A` are overwritten.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator matrix to be filled (modified in place)
        equation (str): The equation of the component you want this in
        relations: (dict): Relations between quantities, mapping component names to operators
    """
    for component, operator in relations.items():
        A[self.index(equation)][self.index(component)] = operator

1289 

def convert_operator_matrix_to_operator(self, M):
    """
    Assemble a list of lists of sparse matrices into a single sparse matrix that can be used as linear operator.
    With a single solution component, the only entry of `M` is returned as is.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator matrix to be converted

    Returns:
        sparse linear operator
    """
    if len(self.components) != 1:
        return self.sparse_lib.bmat(M, format='csc')
    return M[0][0]

1305 

def get_wavenumbers(self):
    """
    Get grid in spectral space

    Returns:
        result of `meshgrid` applied to the local part of the wavenumbers of each axis, with the axes passed in
        reversed order
    """
    local_wavenumbers = []
    for i, axis in enumerate(self.axes):
        local_wavenumbers.append(axis.get_wavenumbers()[self.local_slice[i]])
    return self.xp.meshgrid(*local_wavenumbers[::-1])

1312 

def get_grid(self):
    """
    Get grid in physical space

    Returns:
        result of `meshgrid` applied to the local part of the 1D grid of each axis, with the axes passed in
        reversed order
    """
    local_grids = []
    for i, axis in enumerate(self.axes):
        local_grids.append(axis.get_1dgrid()[self.local_slice[i]])
    return self.xp.meshgrid(*local_grids[::-1])

1319 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get a (cached) transform. Without a communicator, these are the `fftn` / `ifftn` functions of the array
    module and there is no transform object (`None`). When using MPI, we use `PFFT` objects generated by
    mpi4py-fft.

    Args:
        axes (tuple): Axes you want to transform over
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform

    Raises:
        ValueError: If `direction` is not one of "forward", "backward" or "object"
    """
    # an unknown direction would otherwise store nothing in the cache and fail with an opaque KeyError below
    if direction not in ('forward', 'backward', 'object'):
        raise ValueError(f'Unknown {direction=!r}! Use "forward", "backward" or "object".')

    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = [1] * self.ndim if padding is None else padding

    # convert to tuples so the key is hashable even if lists were passed
    key = (tuple(axes), direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache:
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # the transform functions are methods of the (cached) PFFT object
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    return self.fft_cache[key]

1380 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.
    If no component has been added yet, a single component called 'u' is added.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
    self.local_slice = [slice(0, me.N) for me in self.axes]

    axes = tuple(range(len(self.axes)))
    self.fft_obj = self.get_fft(axes=axes, direction='object')
    if self.fft_obj is not None:
        # with a transform object, it decides which part of the data is local
        self.local_slice = self.fft_obj.local_slice(False)

    # Shape of the local data: all components, and in each direction as many entries as the local slice selects.
    # This is computed from the slices directly rather than by slicing an array of the full global shape.
    local_shape = (self.global_shape[0],) + tuple(
        len(range(*me.indices(N))) for me, N in zip(self.local_slice, self.global_shape[1:])
    )

    self.init = (
        local_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.xp.zeros(
        shape=self.init[0],
        dtype=bool,
    )

1426 

1427 def _transform_fft(self, u, axes, **kwargs): 

1428 """ 

1429 FFT along `axes` 

1430 

1431 Args: 

1432 u: The solution 

1433 axes (tuple): Axes you want to transform over 

1434 

1435 Returns: 

1436 transformed solution 

1437 """ 

1438 # TODO: clean up and try putting more of this in the 1D bases 

1439 fft = self.get_fft(axes, 'forward', **kwargs) 

1440 return fft(u, axes=axes) 

1441 

def _transform_dct(self, u, axes, padding=None, **kwargs):
    '''
    DCT along `axes`, computed via an FFT of the reordered data followed by multiplication with a shift.
    This will only return real values!
    When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
    wavenumbers in FFTs.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over
        padding (list): Padding factors for the axes, optional. If given, the modes along the transformed axis are
                        cropped after the transform where the factor is not 1.

    Returns:
        transformed solution
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    if len(axes) > 1:
        # transform over one axis at a time: first over all but the first axis, then over the first one
        # NOTE(review): `padding` is not passed on in the recursion -- confirm this is intended
        v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
    else:
        v = u.copy().astype(complex)
        axis = axes[0]
        base = self.axes[axis]

        # reorder the data along the transformed axis as required by the 1D base
        shuffle = [slice(0, s, 1) for s in u.shape]
        shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
        v = v[(*shuffle,)]

        if padding is not None:
            shape = list(v.shape)
            if self.comm:
                # the first dimension of the transform shape is the reduction of the local sizes over all ranks
                send_buf = np.array(v.shape[0])
                recv_buf = np.array(v.shape[0])
                self.comm.Allreduce(send_buf, recv_buf)
                shape[0] = int(recv_buf)
            fft = self.get_fft(axes, 'forward', shape=shape)
        else:
            fft = self.get_fft(axes, 'forward', **kwargs)

        v = fft(v, axes=axes)

        # index that broadcasts the 1D shift along the transformed axis
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        if padding is not None:
            shift = base.get_fft_shift(True, v.shape[axis])

            if padding[axis] != 1:
                # keep only the first N modes along the transformed axis
                N = int(np.ceil(v.shape[axis] / padding[axis]))
                _expansion = [slice(0, n) for n in v.shape]
                _expansion[axis] = slice(0, N, 1)
                v = v[(*_expansion,)]
        else:
            shift = base.fft_utils['fwd']['shift']

        v *= shift[(*expansion,)]

    return v.real

1501 

def transform_single_component(self, u, axes=None, padding=None):
    """
    Transform a single component of the solution.
    The axes are grouped by the type of their 1D base and transformed one group after the other. Before and after
    each group, the data is realigned with `get_aligned`, which only has an effect in distributed runs.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # transform function for each type of 1D base
    trfs = {
        ChebychevHelper: self._transform_dct,
        UltrasphericalHelper: self._transform_dct,
        FFTHelper: self._transform_fft,
    }

    # default: all axes, counted from the back, and no padding
    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by base type, then drop the base types that have no axes
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]

    fft = self.get_fft(axes=axes, padding=padding, direction='object')
    if fft is not None:
        shape = list(fft.global_shape(False))

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        # along the axes transformed now, use the global resolution as shape
        for _ax in _axes:
            shape[_ax] = self.global_shape[1 + self.ndim + _ax]

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
        )

        alignment = self.ndim + _axes[-1]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        if self.comm is not None:
            # with a communicator, the result is rescaled by the total resolution of the transformed axes
            _out *= np.prod([self.axes[i].N for i in _axes])

        # realign for the next group of axes; the appended `(-1,)` is the target after the last group
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
        )

    return result

1573 

def transform(self, u, axes=None, padding=None):
    """
    Transform all components from physical space to spectral space.
    Each component is transformed separately with `transform_single_component` and the results are stacked.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    if axes is None:
        # all axes, counted from the back
        axes = tuple(-i - 1 for i in range(self.ndim)[::-1])
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.transform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)

1605 

1606 def _transform_ifft(self, u, axes, **kwargs): 

1607 # TODO: clean up and try putting more of this in the 1D bases 

1608 ifft = self.get_fft(axes, 'backward', **kwargs) 

1609 return ifft(u, axes=axes) 

1610 

def _transform_idct(self, u, axes, padding=None, **kwargs):
    '''
    Inverse DCT along `axes`, computed by multiplying with a shift, an inverse FFT and reordering of the result.
    This will only ever return real values!

    Args:
        u: The solution in spectral space
        axes (tuple): Axes you want to transform over
        padding (list): Padding factors for the axes, optional. Where the factor is not 1, zeros are appended
                        along the transformed axis before transforming.

    Returns:
        transformed solution
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    v = u.copy().astype(complex)

    if len(axes) > 1:
        # transform over one axis at a time: first over all but the first axis, then over the first one
        # NOTE(review): neither `padding` nor `kwargs` are passed on in the recursion -- confirm this is intended
        v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
    else:
        axis = axes[0]
        base = self.axes[axis]

        if padding is not None:
            if padding[axis] != 1:
                # append zeros along the transformed axis and build a shift for the padded resolution
                N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
                _pad = [[0, 0] for _ in v.shape]
                _pad[axis] = [0, N_pad - base.N]
                v = self.xp.pad(v, _pad, 'constant')

                shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
            else:
                shift = base.fft_utils['bck']['shift']
        else:
            shift = base.fft_utils['bck']['shift']

        # index that broadcasts the 1D shift along the transformed axis
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        v *= shift[(*expansion,)]

        if padding is not None:
            if padding[axis] != 1:
                shape = list(v.shape)
                if self.comm:
                    # the first dimension of the transform shape is the reduction of the local sizes over all ranks
                    send_buf = np.array(v.shape[0])
                    recv_buf = np.array(v.shape[0])
                    self.comm.Allreduce(send_buf, recv_buf)
                    shape[0] = int(recv_buf)
                ifft = self.get_fft(axes, 'backward', shape=shape)
            else:
                ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        else:
            ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        v = ifft(v, axes=axes)

        # reorder the data along the transformed axis as required by the 1D base
        shuffle = [slice(0, s, 1) for s in v.shape]
        shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
        v = v[(*shuffle,)]

    return v.real

1665 

def itransform_single_component(self, u, axes=None, padding=None):
    """
    Inverse transform over single component of the solution.
    The axes are grouped by the type of their 1D base and transformed one group after the other. Before and after
    each group, the data is realigned with `get_aligned`, which only has an effect in distributed runs.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    # inverse transform function for each type of 1D base
    trfs = {
        FFTHelper: self._transform_ifft,
        ChebychevHelper: self._transform_idct,
        UltrasphericalHelper: self._transform_idct,
    }

    # default: all axes, counted from the back, and no padding
    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by base type, then drop the base types that have no axes
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = list(self.global_shape[1:])

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
        )
        if self.comm is not None:
            # with a communicator, the input is rescaled by the total resolution of the transformed axes
            _in /= np.prod([self.axes[i].N for i in _axes])

        alignment = self.ndim + _axes[0]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # update the shape along the transformed axes for the transforms of the next group
        for _ax in _axes:
            if fft:
                shape[_ax] = fft._input_shape[_ax]
            else:
                shape[_ax] = _out.shape[_ax]

        # realign for the next group of axes; the appended `(-1,)` is the target after the last group
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
        )

    return result

1735 

def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
    """
    Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
    object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
    shape especially, if applicable.
    Without a communicator, with a communicator of size one, or if the alignment does not change, a copy of the
    input is returned.

    Args:
        u: The solution
        axis_in (int): Current alignment
        axis_out (int): New alignment
        fft (mpi4py_fft.PFFT), optional: parallel FFT object
        forward (bool): Whether the input is in spectral space or not

    Returns:
        solution aligned on `axis_out`
    """
    if self.comm is None or axis_in == axis_out:
        return u.copy()
    if self.comm.size == 1:
        return u.copy()

    fft = self.get_fft(**kwargs) if fft is None else fft

    # the transfer objects are always taken from the PFFT object selected by `kwargs`, not from the `fft` argument
    global_fft = self.get_fft(**kwargs)
    axisA = [me.axisA for me in global_fft.transfer]
    axisB = [me.axisB for me in global_fft.transfer]

    current_axis = axis_in

    if axis_in in axisA and axis_out in axisB:
        # chain forward transfers from alignment A to alignment B until the requested alignment is reached
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisA.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayA[:] = u[:]

            transfer.forward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisB
            u = arrayB

        return u
    elif axis_in in axisB and axis_out in axisA:
        # chain backward transfers from alignment B to alignment A until the requested alignment is reached
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisB.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayB[:] = u[:]

            transfer.backward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisA
            u = arrayA

        return u
    else:  # go the potentially slower route of not reusing transfer classes
        from mpi4py_fft import newDistArray

        _in = newDistArray(fft, forward).redistribute(axis_in)
        _in[...] = u

        return _in.redistribute(axis_out)

1800 

def itransform(self, u, axes=None, padding=None):
    """
    Inverse transform over all components of the solution.
    Each component is transformed separately with `itransform_single_component` and the results are stacked.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before transforming

    Returns:
        Transformed data
    """
    if axes is None:
        # all axes, counted from the back
        axes = tuple(-i - 1 for i in range(self.ndim)[::-1])
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.itransform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)

1832 

def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Restrict a 1D matrix to the modes carried by this rank. When using distributed FFTs, each rank will carry only
    a subset of modes, which is stored in the `SpectralHelper.local_slice` attribute. Both the rows and the
    columns of `M` are restricted to the local slice along `axis`.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions. This means slab decomposition only.

    Returns:
        sparse local matrix
    """
    local = self.local_slice[axis]
    return M.tocsc()[local, local]

1847 

def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. In more than one dimension, this is the Kronecker product of the 1D filter
    along `axis` with the identities of the other axes. See the documentation `get_filter_matrix` in the 1D bases
    for what kwargs are admissible.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    factors = [me.get_Id() for me in self.axes]
    factors[axis] = self.axes[axis].get_filter_matrix(**kwargs)
    return self.sparse_lib.kron(*factors)

1862 

def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation.
    In one dimension, the matrix of the only axis is returned. In two dimensions, the local part of the 1D
    differentiation matrix is combined with the local part of an identity in the other direction via a Kronecker
    product, and the matrices of all requested axes are multiplied.

    Args:
        axes (tuple): Axes along which to differentiate.

    Returns:
        sparse differentiation matrix

    Raises:
        NotImplementedError: If the helper has neither one nor two axes
    """
    sparse = self.sparse_lib
    n_axes = self.ndim

    if n_axes == 1:
        return self.axes[0].get_differentiation_matrix(**kwargs)
    if n_axes != 2:
        raise NotImplementedError(f'Differentiation matrix not implemented for {n_axes} dimension!')

    for axis in axes:
        other = (axis + 1) % n_axes
        D1D = self.axes[axis].get_differentiation_matrix(**kwargs)

        # when differentiating along more than one axis, a plain identity is used in the other direction
        if len(axes) > 1:
            I1D = sparse.eye(self.axes[other].N)
        else:
            I1D = self.axes[other].get_Id()

        factors = [None] * n_axes
        factors[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
        factors[other] = self.get_local_slice_of_1D_matrix(I1D, other)

        # start with the matrix of the first axis and multiply the others on from the right
        D = sparse.kron(*factors) if axis == axes[0] else D @ sparse.kron(*factors)

    return D

1900 

def get_integration_matrix(self, axes):
    """
    Get integration matrix to integrate along specified axis.

    In 2D, one Kronecker-product matrix is built per entry of `axes` and the matrices are multiplied in the
    order given. An axis may appear more than once, e.g. `axes=(0, 0)` integrates twice along axis 0.

    Args:
        axes (tuple): Axes along which to integrate over. Ignored in 1D.

    Returns:
        sparse integration matrix

    Raises:
        NotImplementedError: If the helper has neither one nor two dimensions
        ValueError: If `axes` is empty in 2D
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        return self.axes[0].get_integration_matrix()

    if ndim != 2:
        raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')

    if len(axes) == 0:
        raise ValueError('Need at least one axis to construct an integration matrix!')

    S = None
    for axis in axes:
        axis2 = (axis + 1) % ndim
        S1D = self.axes[axis].get_integration_matrix()

        if len(axes) > 1:
            I1D = sp.eye(self.axes[axis2].N)
        else:
            I1D = self.axes[axis2].get_Id()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
        mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
        S_axis = sp.kron(*mats)

        # Accumulate by position rather than by comparing with `axes[0]`, so that a repeated
        # first axis multiplies onto the product instead of restarting it.
        S = S_axis if S is None else S @ S_axis

    return S

1938 

def get_Id(self):
    """
    Get identity matrix

    In 1D the identity of the only base is returned. In 2D a complex-valued identity of size
    `prod(self.init[0][1:])` is multiplied by one Kronecker product of (locally sliced) 1D identities per axis.

    Returns:
        sparse identity matrix

    Raises:
        NotImplementedError: If the helper has neither one nor two dimensions
    """
    sparse = self.sparse_lib
    ndim = self.ndim

    # starting point of the 2D product; constructed before branching on the dimension
    identity = sparse.eye(np.prod(self.init[0][1:]), dtype=complex)

    if ndim == 1:
        return self.axes[0].get_Id()

    if ndim != 2:
        raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!')

    for this_axis in range(ndim):
        other_axis = (this_axis + 1) % ndim
        base_identity = self.axes[this_axis].get_Id()
        plain_identity = sparse.eye(self.axes[other_axis].N)

        factors = [None] * ndim
        factors[this_axis] = self.get_local_slice_of_1D_matrix(base_identity, this_axis)
        factors[other_axis] = self.get_local_slice_of_1D_matrix(plain_identity, other_axis)

        identity = identity @ sparse.kron(*factors)

    return identity

1968 

def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get Dirichlet recombination matrix along axis. Note that it only makes sense in directions discretized with
    variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev. Ignored in 1D.

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: If the helper has neither one nor two dimensions
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        C = self.axes[0].get_Dirichlet_recombination_matrix()
    elif ndim == 2:
        axis2 = (axis + 1) % ndim
        C1D = self.axes[axis].get_Dirichlet_recombination_matrix()

        I1D = self.axes[axis2].get_Id()

        # recombination along `axis`, identity along the other direction
        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
        mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

        C = sp.kron(*mats)
    else:
        raise NotImplementedError(f'Dirichlet recombination matrix not implemented for {ndim} dimension!')

    return C

1999 

def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases do a change between bases while differentiating. This method returns matrices that change
    the basis to whatever you want.
    Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`.

    In 2D, one Kronecker-product matrix is built per entry of `axes` and the matrices are multiplied in the
    order given. An axis may appear more than once, in which case its basis change is applied repeatedly.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes, counted from the last one.
            Ignored in 1D.

    Returns:
        sparse basis change matrix

    Raises:
        NotImplementedError: If the helper has neither one nor two dimensions
        ValueError: If `axes` is empty in 2D
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes

    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        return self.axes[0].get_basis_change_matrix(**kwargs)

    if ndim != 2:
        raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')

    if len(axes) == 0:
        raise ValueError('Need at least one axis to construct a basis change matrix!')

    C = None
    for axis in axes:
        axis2 = (axis + 1) % ndim
        C1D = self.axes[axis].get_basis_change_matrix(**kwargs)

        if len(axes) > 1:
            I1D = sp.eye(self.axes[axis2].N)
        else:
            I1D = self.axes[axis2].get_Id()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
        mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
        C_axis = sp.kron(*mats)

        # Accumulate by position rather than by comparing with `axes[0]`, so that a repeated
        # first axis multiplies onto the product instead of restarting it.
        C = C_axis if C is None else C @ C_axis

    return C