Coverage for pySDC/helpers/spectral_helper.py: 92%

778 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-20 10:09 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5 

6 

class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain, or None if one of the boundaries was not given
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np

    def __init__(self, N, x0=None, x1=None, useGPU=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
        """
        self.N = N
        self.x0 = x0
        self.x1 = x1
        # The length is only defined when both boundaries are known. Computing it unconditionally made the default
        # arguments `x0=None, x1=None` unusable because `None - None` raises a TypeError.
        self.L = None if x0 is None or x1 is None else x1 - x0
        self.useGPU = useGPU

        if useGPU:
            self.setup_GPU()

    @classmethod
    def setup_GPU(cls):
        """
        Switch to GPU modules.

        Note that the modules are replaced on the class `cls` rather than on an instance, so this affects every object
        that reads these class attributes.
        """
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
            individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. This is a diagonal matrix that keeps the modes with `kmin <= |k| < kmax` and removes all
        others.

        Args:
            kmin (int): Lower limit (inclusive) of the bandpass filter
            kmax (int): Upper limit (exclusive) of the bandpass filter. Defaults to the largest `|k|`, which means the
                mode(s) with the largest absolute wavenumber are removed by default.

        Returns:
            sparse matrix
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        keep = self.xp.logical_and(k < kmax, k >= kmin)

        # Zeroing columns of the identity is the same as a diagonal matrix that is one only for the kept modes.
        # Building it directly works with the CPU and the GPU sparse modules alike, without moving data between them.
        return self.sparse_lib.diags(keep.astype(float)).tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

172 

173 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
    """

    def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
                use the FFT from the library to compute the DCT. Compared case-insensitively.
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented
        """
        assert x0 == -1
        assert x1 == 1
        super().__init__(*args, x0=x0, x1=x1, **kwargs)
        self.transform_type = transform_type

        # `transform` and `itransform` compare the transform type case-insensitively, so do the same here to make sure
        # the utilities for the FFT-based DCT exist whenever they are going to be used.
        if self.transform_type.lower() == 'fft':
            self.get_fft_utils()

        self.cache = {}
        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5))

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices of the default size `self.N` are cached. So feel free to call the method as often as
        you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its reverse is known
        '''
        N = N if N else self.N

        # Only matrices of the default size are cached. Storing other sizes under the same key would make later calls
        # without `N` return a matrix of the wrong size.
        use_cache = N == self.N
        if use_cache and name in self.cache:
            return self.cache[name]

        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)) / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # try to get the reverse conversion and invert it
            try:
                fwd = get_forward_conv(name[::-1])
                import scipy.sparse as scipy_sparse

                if self.sparse_lib == scipy_sparse:
                    mat = self.sparse_lib.linalg.inv(fwd.tocsc())
                else:
                    # the inversion is done on the CPU and the result moved back to the GPU
                    mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))
            except NotImplementedError:
                raise NotImplementedError from E

        if use_cache:
            self.cache[name] = mat
        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: If `lbnd` is not 0
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        n = xp.arange(self.N)
        S = (self.sparse_lib.diags(1 / (n[:-1] + 1), offsets=-1) @ self.get_conv('T2U')).tocsc()

        # The first row fixes the integration constant. `xp.append` instead of `np.append` keeps this usable on GPUs.
        S[0, 1::2] = (
            (n / (2 * (n + 1)))[1::2]
            * (-1) ** xp.arange(self.N // 2)
            / xp.append(xp.array([1]), xp.arange(self.N // 2 - 1) + 1)
        )
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        xp = self.xp
        rows = xp.arange(self.N)[:, None]
        cols = xp.arange(self.N)[None, :]

        # Entry (row, col) is 2 * col if row < col and col - row is odd, and zero otherwise.
        D = xp.triu(2.0 * cols * ((cols - rows) % 2), k=1)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p))

    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def get_fft_shuffle(self, forward, N):
        """
        In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
        doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
        in the complex plane. This function returns a mask to do the shuffling.

        Args:
            forward (bool): Whether you want the shuffle for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Use as mask
        """
        xp = self.xp
        if forward:
            return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
        else:
            mask = xp.zeros(N, dtype=int)
            mask[: N - N % 2 : 2] = xp.arange(N // 2)
            mask[1::2] = N - xp.arange(N // 2) - 1
            mask[-1] = N // 2
            return mask

    def get_fft_shift(self, forward, N):
        """
        As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.

        Args:
            forward (bool): Whether you want the rotation for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Rotation
        """
        xp = self.xp
        # use the requested size `N` throughout rather than the resolution of this base
        k = xp.arange(N)
        norm = self.get_norm(N)
        if forward:
            return 2 * xp.exp(-1j * np.pi * k / (2 * N)) * norm
        else:
            shift = xp.exp(1j * np.pi * k / (2 * N))
            shift[0] = 0.5
            return shift / norm

    def get_fft_utils(self):
        """
        Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
        them cached.
        """
        self.fft_utils = {
            'fwd': {},
            'bck': {},
        }

        # forwards transform
        self.fft_utils['fwd']['shuffle'] = self.get_fft_shuffle(True, self.N)
        self.fft_utils['fwd']['shift'] = self.get_fft_shift(True, self.N)

        # backwards transform
        self.fft_utils['bck']['shuffle'] = self.get_fft_shuffle(False, self.N)
        self.fft_utils['bck']['shift'] = self.get_fft_shift(False, self.N)

        return self.fft_utils

    @staticmethod
    def _get_expansion(shape, axis):
        """Index tuple that broadcasts a 1D array of length `shape[axis]` along `axis` of an array of `shape`."""
        expansion = [np.newaxis for _ in shape]
        expansion[axis] = slice(0, shape[axis], 1)
        return tuple(expansion)

    def transform(self, u, axis=-1, **kwargs):
        """
        1D DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in spectral space

        Raises:
            NotImplementedError: If the transform type is neither 'fft' nor 'dct'
        """
        transform_type = self.transform_type.lower()
        expansion = self._get_expansion(u.shape, axis)

        if transform_type == 'dct':
            # The normalization has to be applied along `axis`, which is not necessarily the last one.
            return self.fft_lib.dct(u, axis=axis, **kwargs) * self.norm[expansion]
        elif transform_type == 'fft':
            result = u.copy()

            # reshuffle the data along `axis` such that the DCT can be computed with an FFT
            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['fwd']['shuffle']
            v = u[(*shuffle,)]

            V = self.fft_lib.fft(v, axis=axis, **kwargs)
            V *= self.fft_utils['fwd']['shift'][expansion]

            # only the real part is written back; `result` keeps the dtype of `u`
            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def itransform(self, u, axis=-1):
        """
        1D inverse DCT along axis.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in physical space

        Raises:
            NotImplementedError: If the transform type is neither 'fft' nor 'dct'
        """
        assert self.norm.shape[0] == u.shape[axis]

        transform_type = self.transform_type.lower()
        expansion = self._get_expansion(u.shape, axis)

        if transform_type == 'dct':
            # The normalization has to be applied along `axis`, which is not necessarily the last one.
            return self.fft_lib.idct(u / self.norm[expansion], axis=axis)
        elif transform_type == 'fft':
            result = u.copy()

            v = self.fft_lib.ifft(u * self.fft_utils['bck']['shift'][expansion], axis=axis)

            # undo the reshuffling of the forward transform along `axis`
            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['bck']['shuffle']
            V = v[(*shuffle,)]

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition, must lie in [-1, 1]

        Returns:
            self.xp.ndarray: Row to put into a matrix

        Raises:
            NotImplementedError: If `x` lies outside of [-1, 1]
        """
        # The special cases give exact zeros and ones, which keeps the rows sparse.
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        elif -1 < x < 1:
            # general case using T_k(x) = cos(k * arccos(x)) on [-1, 1]
            return self.xp.cos(self.xp.arange(self.N) * np.arccos(x))
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)

541 

542 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from
    0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        if p == 0:
            # The general formula below contains factorial(p - 1), which scipy evaluates to 0 for p = 0. It would then
            # return a zero matrix, while the zeroth derivative is the identity.
            return sp.eye(N)
        return 2 ** (p - 1) * factorial(p - 1) * sp.diags(xp.arange(N - p) + p, offsets=p)

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # assembled with scipy on the CPU and converted to the configured sparse module on return
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`. `**kwargs` absorb arguments meant for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        if p_in == p_out:
            # Nothing to convert. Return the identity right away instead of inverting it, which is costly and requires a
            # round trip to the CPU when running on GPUs.
            return self.sparse_lib.eye(self.N, format='csc')

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min(p_in, p_out), max(p_in, p_out)):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        import scipy.sparse as sp

        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = sp.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_basis_change_matrix(
            p_out=1, p_in=0
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix`
        for details. The constant is the sum over k >= 1 of u_hat[k] * (-1)**(k - 1) along `axis`.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp
        n = u_hat.shape[axis]

        # all coefficients but the first along `axis`, keeping all other axes untouched
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # alternating signs, broadcast along `axis` (not necessarily the last one)
        expansion = [np.newaxis] * u_hat.ndim
        expansion[axis] = slice(None)
        signs = ((-1) ** xp.arange(n - 1))[tuple(expansion)]

        return xp.sum(u_hat[tuple(slices)] * signs, axis=axis)

670 

671 

class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic domains. Transforms between physical and spectral space use the FFT and the
    differentiation matrix is diagonal in spectral space.
    """

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        Get `N` equally spaced points that include the left boundary, but not the right one, because by periodicity the
        right boundary is identified with the left one.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Get the wavenumbers in the order used by the FFT, i.e. zero first, then the positive ones and then the negative
        ones. Be careful that this ordering is very unintuitive.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative, must be non-negative

        Returns:
            sparse differentiation matrix

        Raises:
            ValueError: If `p` is negative
        """
        if p < 0:
            raise ValueError(f'The order of the derivative must be non-negative, but got {p=}')

        k = self.get_wavenumbers()

        # Powers of a diagonal matrix are entry-wise powers of the diagonal. This needs no sparse `matrix_power`, which
        # is not implemented in CuPy, so GPUs need no round trip through the CPU.
        return self.sparse_lib.diags((1j * k) ** p).tocsc()

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute, must be non-negative

        Returns:
            sparse integration matrix

        Raises:
            ValueError: If `p` is negative
        """
        if p < 0:
            raise ValueError(f'The order of the integral must be non-negative, but got {p=}')

        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        # The zero mode would lead to a division by zero. It is replaced by 1j * L, which makes its entry -1 / L.
        k[0] = 1j * self.L

        # Entry-wise power of the diagonal, which works without a sparse `matrix_power` on CPUs and GPUs alike.
        return self.sparse_lib.diags((1 / (1j * k)) ** p).tocsc()

    def transform(self, u, axis=-1, **kwargs):
        """
        1D FFT along axis. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.fft(u, axis=axis, **kwargs)

    def itransform(self, u, axis=-1):
        """
        Inverse 1D FFT.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.ifft(u, axis=axis)

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        # argmin returns the first occurrence of the lowest wavenumber
        return int(self.xp.argmin(k))

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me

806 

807 

808class SpectralHelper: 

809 """ 

810 This class has three functions: 

811 - Easily assemble matrices containing multiple equations 

812 - Direct product of 1D bases to solve problems in more dimensions 

813 - Distribute the FFTs to facilitate concurrency. 

814 

815 Attributes: 

816 comm (mpi4py.Intracomm): MPI communicator 

817 debug (bool): Perform additional checks at extra computational cost 

818 useGPU (bool): Whether to use GPUs 

819 axes (list): List of 1D bases 

820 components (list): List of strings of the names of components in the equations 

821 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

822 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

823 BCs (sparse matrix): Matrix containing only the BCs 

824 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

825 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

826 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs 

827 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

828 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

829 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

830 local_slice (slice): Local slice of the solution as in `mpi4py-fft` 

831 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

832 init (tuple): This is the same `init` that is used throughout the problem classes 

833 init_forward (tuple): This is the equivalent of `init` in spectral space 

834 """ 

835 

836 xp = np 

837 fft_lib = scipy.fft 

838 sparse_lib = scipy.sparse 

839 linalg = scipy.sparse.linalg 

840 dtype = mesh 

841 fft_backend = 'fftw' 

842 fft_comm_backend = 'MPI' 

843 

844 @classmethod 

845 def setup_GPU(cls): 

846 """switch to GPU modules""" 

847 import cupy as cp 

848 import cupyx.scipy.sparse as sparse_lib 

849 import cupyx.scipy.sparse.linalg as linalg 

850 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh 

851 

852 cls.xp = cp 

853 cls.sparse_lib = sparse_lib 

854 cls.linalg = linalg 

855 

856 cls.fft_backend = 'cupy' 

857 cls.fft_comm_backend = 'NCCL' 

858 

859 cls.dtype = cupy_mesh 

860 

861 def __init__(self, comm=None, useGPU=False, debug=False): 

862 """ 

863 Constructor 

864 

865 Args: 

866 comm (mpi4py.Intracomm): MPI communicator 

867 useGPU (bool): Whether to use GPUs 

868 debug (bool): Perform additional checks at extra computational cost 

869 """ 

870 self.comm = comm 

871 self.debug = debug 

872 self.useGPU = useGPU 

873 

874 if useGPU: 

875 self.setup_GPU() 

876 

877 self.axes = [] 

878 self.components = [] 

879 

880 self.full_BCs = [] 

881 self.BC_mat = None 

882 self.BCs = None 

883 

884 self.fft_cache = {} 

885 self.fft_dealias_shape_cache = {} 

886 

887 @property 

888 def u_init(self): 

889 """ 

890 Get empty data container in physical space 

891 """ 

892 return self.dtype(self.init) 

893 

894 @property 

895 def u_init_forward(self): 

896 """ 

897 Get empty data container in spectral space 

898 """ 

899 return self.dtype(self.init_forward) 

900 

901 @property 

902 def shape(self): 

903 """ 

904 Get shape of individual solution component 

905 """ 

906 return self.init[0][1:] 

907 

908 @property 

909 def ndim(self): 

910 return len(self.axes) 

911 

912 @property 

913 def ncomponents(self): 

914 return len(self.components) 

915 

916 @property 

917 def V(self): 

918 """ 

919 Get domain volume 

920 """ 

921 return np.prod([me.L for me in self.axes]) 

922 

def add_axis(self, base, *args, **kwargs):
    """
    Add an axis to the domain, using the 1D spectral base selected by `base`.
    `*args` and `**kwargs` are forwarded to the constructor of the 1D base; see the documentation of the 1D bases
    for admissible arguments. `useGPU` is always taken from this helper.

    Args:
        base (str): 1D spectral method (case-insensitive), e.g. 'chebychev', 'fft' or 'ultraspherical'

    Raises:
        NotImplementedError: if `base` does not name a known 1D base
    """
    kwargs['useGPU'] = self.useGPU
    requested = base.lower()

    if requested in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
        # Chebychev bases use the FFT based transform unless told otherwise
        kwargs.setdefault('transform_type', 'fft')
        helper_class = ChebychevHelper
    elif requested in ['fft', 'fourier', 'ffthelper']:
        helper_class = FFTHelper
    elif requested in ['ultraspherical', 'gegenbauer']:
        helper_class = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    new_axis = helper_class(*args, **kwargs)
    # share the (possibly GPU) array and sparse modules of this helper with the 1D base
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
    self.axes.append(new_axis)

945 

def add_component(self, name):
    """
    Register one or several solution components.

    Args:
        name (str or list/tuple of str): Name(s) of the component(s); lists/tuples may be nested

    Raises:
        Exception: if a component with this name was already added
        NotImplementedError: if `name` is neither a str nor a list/tuple
    """
    kind = type(name)

    if kind in [str]:
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
        return

    if kind not in [list, tuple]:
        raise NotImplementedError

    for single_name in name:
        self.add_component(single_name)

962 

def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str, int, list or tuple): Name of a component, or a list/tuple of names

    Returns:
        int for a single name. For a list/tuple, a tuple of ints in the same order. A tuple is returned rather than
        a one-shot generator, so the result can be unpacked, iterated repeatedly, measured with `len` and used for
        indexing.

    Raises:
        ValueError: if a name is not a registered component
        NotImplementedError: if `name` has an unsupported type
    """
    if type(name) in [str, int]:
        return self.components.index(name)
    elif type(name) in [list, tuple]:
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

979 

def get_empty_operator_matrix(self):
    """
    Return an S x S nested list (S = number of components) to be filled with the operators that connect the
    solution components.

    Returns:
        list of lists of sparse zeros (`self.get_Id() * 0`)
    """
    zero = self.get_Id() * 0
    n_comp = len(self.components)
    # NOTE(review): every cell references the very same zero-matrix object; entries are expected to be replaced
    # (as `add_equation_lhs` does) rather than modified in place — confirm for any in-place sparse backend.
    return [[zero] * n_comp for _row in range(n_comp)]

990 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC only in the first position in the other direction instead of in all of them

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: for problems with other than one or two dimensions
    """
    ndim = len(self.axes)
    if ndim not in (1, 2):
        # previously this fell through and silently returned None
        raise NotImplementedError(f'Boundary conditions are not implemented for {ndim} dimensions!')

    sp = scipy.sparse

    base = self.axes[axis]

    # global 1D matrix that is zero except for row `line`, which holds the BC row of the 1D base
    BC = sp.eye(base.N).tolil() * 0
    BC_row = base.get_BC(kind=kind, **kwargs)
    # the matrix is assembled with scipy, so GPU data is moved to the host first
    BC[line, :] = BC_row.get() if self.useGPU else BC_row

    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)

    axis2 = (axis + 1) % ndim

    if scalar:
        # identity restricted to the first entry in the other direction
        _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
    else:
        _Id = self.axes[axis2].get_Id()

    Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

    if self.useGPU:
        Id = Id.get()

    mats = [None] * ndim
    mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
    mats[axis2] = Id
    return self.sparse_lib.csc_matrix(sp.kron(*mats))

1039 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    N = self.axes[axis].N
    global_line = (N + line) % N
    # global rows carried by this rank along `axis`
    local_rows = range(N)[self.local_slice[axis]]
    if global_line not in local_rows:
        return

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
    else:
        # Translate the global row into a local index, as `add_BC` does when setting the mask. Without this, the
        # wrong (or an out-of-range) row is unflagged whenever `axis` is distributed.
        slices = (
            [self.index(equation)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis)]
            + [local_rows.index(global_line)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
        )
    self.BC_rhs_mask[(*slices,)] = False

1073 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # with a communicator, only rank 0 flags the entry
        if self.comm:
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        N = self.axes[axis].N
        global_line = (N + line) % N
        # global rows carried by this rank along `axis`
        local_rows = range(N)[self.local_slice[axis]]
        if global_line in local_rows:
            # Use the non-negative local index. `line - start` is wrong for negative `line` on a distributed axis,
            # e.g. line=-1 with local rows 4..7 gave -5.
            slices = (
                [self.index(equation)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [local_rows.index(global_line)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            self.BC_rhs_mask[(*slices,)] = True

1125 

def setup_BCs(self):
    """
    Convert the list of lists of BCs to the boundary condition operator.
    Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
    condition. This method sets up a suitable sparse diagonal matrix to do this and prepares the BCs in spectral
    space for adding them to the right hand side.
    """
    sp = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)
    # flat indices of all rows that carry a BC
    self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[self.BC_rhs_mask.flatten()]

    diags = self.xp.ones(self.BCs.shape[0])
    diags[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sp.diags(diags)

    # `put_BCs_in_rhs_hat` caches a mask derived from the BCs registered at the time of its first call. Drop it
    # so it is rebuilt for the current set of BCs instead of silently using a stale one.
    if hasattr(self, '_rhs_hat_zero_mask'):
        del self._rhs_hat_zero_mask

    # prepare BCs in spectral space to easily add to the RHS
    rhs_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(rhs_BCs)

1143 

def check_BCs(self, u):
    """
    Assert that `u` satisfies all non-scalar boundary conditions registered via `add_BC`.
    Only problems with fewer than three dimensions are supported.

    Args:
        u: The solution you want to check (physical space)
    """
    assert self.ndim < 3
    not_forwarded = ['component', 'equation', 'axis', 'v', 'line', 'scalar']

    for axis in range(self.ndim):
        relevant_BCs = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not relevant_BCs:
            continue

        # transform only along the axis the BCs act on
        u_hat = self.transform(u, axes=(axis - self.ndim,))

        for BC in relevant_BCs:
            bc_kwargs = {key: value for key, value in BC.items() if key not in not_forwarded}

            if axis == 0:
                get = self.axes[axis].get_BC(**bc_kwargs) @ u_hat[self.index(BC['component'])]
            elif axis == 1:
                get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**bc_kwargs)
            want = BC['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'

1172 

def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions in a matrix by replacing rows with BCs: the rows flagged for BCs are zeroed by
    `BC_line_zero_matrix` and the BC operator `BCs` is added.

    Args:
        A: operator

    Returns:
        operator with BC rows
    """
    zeroed_rows = self.BC_line_zero_matrix @ A
    return zeroed_rows + self.BCs

1178 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.
    Note that the entries of `rhs_hat` in BC rows are set to zero in place.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate a mask of the entries in the rhs in spectral space that are replaced by the boundary conditions.
        # The mask is cached.
        self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if axis != bc['axis']:
                    continue

                N = self.axes[axis].N
                global_line = (N + bc['line']) % N
                # global rows carried by this rank along `axis`
                local_rows = range(N)[self.local_slice[axis]]
                if global_line not in local_rows:
                    continue

                # use the non-negative local index; `line - start` is wrong for negative lines on distributed axes
                _slice = (
                    [self.index(bc['equation'])]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                    + [local_rows.index(global_line)]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                )
                self._rhs_hat_zero_mask[(*_slice,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1213 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        for bc in self.full_BCs:
            if axis != bc['axis']:
                continue

            N = self.axes[axis].N
            global_line = (N + bc['line']) % N
            # global rows carried by this rank along `axis`
            local_rows = range(N)[self.local_slice[axis]]
            if global_line not in local_rows:
                continue

            # use the non-negative local index; `line - start` is wrong for negative lines on distributed axes
            _slice = (
                [self.index(bc['equation'])]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [local_rows.index(global_line)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            _rhs_hat[(*_slice,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs

1251 

def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator matrix to be filled (modified in place)
        equation (str): The equation of the component you want this in
        relations: (dict): Maps component names to the operators acting on them in this equation
    """
    for component, operator in relations.items():
        row = self.index(equation)
        col = self.index(component)
        A[row][col] = operator

1290 

def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator matrix

    Returns:
        sparse linear operator: `M[0][0]` itself for a single component, otherwise the CSC block matrix of `M`
    """
    if len(self.components) != 1:
        return self.sparse_lib.bmat(M, format='csc')
    return M[0][0]

1306 

def get_wavenumbers(self):
    """
    Get grid in spectral space: the locally carried wavenumbers of every axis, combined with `meshgrid` with the
    axes passed in reverse order.
    """
    local_wavenumbers = [ax.get_wavenumbers()[self.local_slice[i]] for i, ax in enumerate(self.axes)]
    return self.xp.meshgrid(*reversed(local_wavenumbers))

1313 

def get_grid(self):
    """
    Get grid in physical space: the locally carried 1D grid points of every axis, combined with `meshgrid` with
    the axes passed in reverse order.
    """
    local_points = [ax.get_1dgrid()[self.local_slice[i]] for i, ax in enumerate(self.axes)]
    return self.xp.meshgrid(*reversed(local_points))

1320 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get (and cache) a transform. Without a communicator, the `fftn`/`ifftn` functions of `self.xp.fft` are used
    and the "object" is None. When using MPI, we use `PFFT` objects generated by mpi4py-fft.

    Args:
        axes (tuple): Axes you want to transform over
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform
    """
    if axes is None:
        axes = tuple(-i - 1 for i in range(self.ndim))
    if shape is None:
        shape = self.global_shape[1:]
    if padding is None:
        padding = [1] * self.ndim
    key = (axes, direction, tuple(padding), tuple(shape))

    if key in self.fft_cache.keys():
        return self.fft_cache[key]

    if self.comm is None:
        assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

        if direction == 'object':
            self.fft_cache[key] = None
        elif direction in ('forward', 'backward'):
            function_name = 'fftn' if direction == 'forward' else 'ifftn'
            self.fft_cache[key] = getattr(self.xp.fft, function_name)
    else:
        if direction == 'object':
            from mpi4py_fft import PFFT

            parallel_fft = PFFT(
                comm=self.comm,
                shape=shape,
                axes=sorted(axes),
                dtype='D',
                collapse=False,
                backend=self.fft_backend,
                comm_backend=self.fft_comm_backend,
                padding=padding,
            )
        else:
            # the forward/backward functions are bound methods of the (cached) PFFT object
            parallel_fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

        if direction == 'object':
            self.fft_cache[key] = parallel_fft
        elif direction in ('forward', 'backward'):
            self.fft_cache[key] = getattr(parallel_fft, direction)

    return self.fft_cache[key]

1381 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    Sets `global_shape`, `local_slice`, `fft_obj`, `init`, `init_forward`, `BC_mat` and `BC_rhs_mask`. If no
    component was added, a single component 'u' is added.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
    self.local_slice = [slice(0, me.N) for me in self.axes]

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_fft(axes=axes, direction='object')
    if self.fft_obj is not None:
        self.local_slice = self.fft_obj.local_slice(False)

    # Local data shape: apply the local slices to the trailing dimensions of the global shape. Slicing a `range`
    # follows the same rules as slicing an array, so no global-size array has to be allocated just to read the
    # shape (previously `np.empty(global_shape)` was created twice on every rank).
    n_sliced = len(self.local_slice)
    n_leading = len(self.global_shape) - n_sliced
    local_shape = tuple(self.global_shape[:n_leading]) + tuple(
        len(range(n)[sl]) for n, sl in zip(self.global_shape[n_leading:], self.local_slice)
    )

    self.init = (
        local_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.xp.zeros(
        shape=self.init[0],
        dtype=bool,
    )

1427 

1428 def _transform_fft(self, u, axes, **kwargs): 

1429 """ 

1430 FFT along `axes` 

1431 

1432 Args: 

1433 u: The solution 

1434 axes (tuple): Axes you want to transform over 

1435 

1436 Returns: 

1437 transformed solution 

1438 """ 

1439 # TODO: clean up and try putting more of this in the 1D bases 

1440 fft = self.get_fft(axes, 'forward', **kwargs) 

1441 return fft(u, axes=axes) 

1442 

def _transform_dct(self, u, axes, padding=None, **kwargs):
    '''
    DCT along `axes`, computed by reordering ("shuffling") the data along the axis, applying an FFT and
    multiplying with a shift factor from the 1D base.
    This will only return real values!
    When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
    wavenumbers in FFTs.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over
        padding (list or None): Padding factors per axis. If `padding[axis] != 1`, only the first
            ceil(n / padding[axis]) modes along `axis` are kept after the transform.
        **kwargs: Forwarded to `get_fft` when `padding` is None

    Returns:
        transformed solution (real part only)
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    if len(axes) > 1:
        # Transform one axis at a time.
        # NOTE(review): `padding` is not forwarded to these recursive calls — confirm this is intended.
        v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
    else:
        v = u.copy().astype(complex)
        axis = axes[0]
        base = self.axes[axis]

        # reorder the data along `axis` as prescribed by the 1D base before applying the FFT
        shuffle = [slice(0, s, 1) for s in u.shape]
        shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
        v = v[(*shuffle,)]

        if padding is not None:
            shape = list(v.shape)
            # NOTE(review): no visible code writes to `fft_dealias_shape_cache`, so this lookup may always miss
            if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
                shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
            elif self.comm:
                # reduce the local extent of dimension 0 over all ranks (mpi4py's default op is SUM), presumably
                # yielding the global extent — confirm
                send_buf = np.array(v.shape[0])
                recv_buf = np.array(v.shape[0])
                self.comm.Allreduce(send_buf, recv_buf)
                shape[0] = int(recv_buf)
            fft = self.get_fft(axes, 'forward', shape=shape)
        else:
            fft = self.get_fft(axes, 'forward', **kwargs)

        v = fft(v, axes=axes)

        # index that broadcasts a 1D array along `axis` against `v`
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        if padding is not None:
            shift = base.get_fft_shift(True, v.shape[axis])

            if padding[axis] != 1:
                # discard the upper modes: keep ceil(n / padding) modes along `axis`
                # NOTE(review): `expansion` and `shift` are sized before this truncation — confirm the shapes
                # still broadcast in the multiplication below
                N = int(np.ceil(v.shape[axis] / padding[axis]))
                _expansion = [slice(0, n) for n in v.shape]
                _expansion[axis] = slice(0, N, 1)
                v = v[(*_expansion,)]
        else:
            shift = base.fft_utils['fwd']['shift']

        v *= shift[(*expansion,)]

    return v.real

1504 

def transform_single_component(self, u, axes=None, padding=None):
    """
    Transform a single component of the solution.

    The requested axes are grouped by the type of their 1D base, and each group is transformed in one go with the
    matching function from `trfs`. Around each group, the data is realigned with `get_aligned`, which copies the
    data when there is no communicator. With a communicator, the result of each group is multiplied by the number
    of modes along the transformed axes.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    trfs = {
        ChebychevHelper: self._transform_dct,
        UltrasphericalHelper: self._transform_dct,
        FFTHelper: self._transform_fft,
    }

    # default: all axes, as negative indices (-ndim, ..., -1)
    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    # default: no padding along any axis
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by the exact type of their 1D base and drop empty groups
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]

    fft = self.get_fft(axes=axes, padding=padding, direction='object')
    if fft is not None:
        shape = list(fft.global_shape(False))

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        # NOTE(review): this indexing assumes negative axis indices, as produced by the default `axes` — confirm
        # for callers passing non-negative axes
        for _ax in _axes:
            shape[_ax] = self.global_shape[1 + self.ndim + _ax]

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
        )

        alignment = self.ndim + _axes[-1]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        if self.comm is not None:
            _out *= np.prod([self.axes[i].N for i in _axes])

        # realign for the next group of axes (or keep the current alignment after the last group)
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
        )

    return result

1576 

def transform(self, u, axes=None, padding=None):
    """
    Transform all components from physical space to spectral space, one component at a time.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform (default: all axes, as negative indices)
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data, with the components stacked along the first dimension
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.transform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)

1608 

1609 def _transform_ifft(self, u, axes, **kwargs): 

1610 # TODO: clean up and try putting more of this in the 1D bases 

1611 ifft = self.get_fft(axes, 'backward', **kwargs) 

1612 return ifft(u, axes=axes) 

1613 

def _transform_idct(self, u, axes, padding=None, **kwargs):
    '''
    Inverse DCT along `axes`, computed by multiplying with a shift factor, applying an inverse FFT and undoing the
    reordering ("shuffle") of the data along the axis.
    This will only ever return real values!

    Args:
        u: Data in spectral space
        axes (tuple): Axes you want to transform over
        padding (list or None): Padding factors per axis. If `padding[axis] != 1`, the data is zero-padded at the
            end of `axis` to ceil(n * padding[axis]) modes before the transform.
        **kwargs: Forwarded to `get_fft` unless the data is padded

    Returns:
        transformed data (real part only)
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    v = u.copy().astype(complex)

    if len(axes) > 1:
        # Transform one axis at a time.
        # NOTE(review): `padding` and `kwargs` are not forwarded to these recursive calls — confirm this is intended.
        v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
    else:
        axis = axes[0]
        base = self.axes[axis]

        if padding is not None:
            if padding[axis] != 1:
                # append zeros along `axis`
                # NOTE(review): the pad width uses `base.N` rather than the local extent `v.shape[axis]` — confirm
                N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
                _pad = [[0, 0] for _ in v.shape]
                _pad[axis] = [0, N_pad - base.N]
                v = self.xp.pad(v, _pad, 'constant')

                shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
            else:
                shift = base.fft_utils['bck']['shift']
        else:
            shift = base.fft_utils['bck']['shift']

        # index that broadcasts a 1D array along `axis` against `v`
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        v *= shift[(*expansion,)]

        if padding is not None:
            if padding[axis] != 1:
                shape = list(v.shape)
                # NOTE(review): no visible code writes to `fft_dealias_shape_cache`, so this lookup may always miss
                if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
                    shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
                elif self.comm:
                    # reduce the local extent of dimension 0 over all ranks (mpi4py's default op is SUM),
                    # presumably yielding the global extent — confirm
                    send_buf = np.array(v.shape[0])
                    recv_buf = np.array(v.shape[0])
                    self.comm.Allreduce(send_buf, recv_buf)
                    shape[0] = int(recv_buf)
                ifft = self.get_fft(axes, 'backward', shape=shape)
            else:
                ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        else:
            ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        v = ifft(v, axes=axes)

        # undo the reordering along `axis`
        shuffle = [slice(0, s, 1) for s in v.shape]
        shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
        v = v[(*shuffle,)]

    return v.real

1670 

def itransform_single_component(self, u, axes=None, padding=None):
    """
    Inverse transform over single component of the solution.

    The requested axes are grouped by the type of their 1D base, and each group is transformed with the matching
    inverse function from `trfs`. Around each group, the data is realigned with `get_aligned`, which copies the
    data when there is no communicator. With a communicator, the input of each group is divided by the number of
    modes along the transformed axes.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    trfs = {
        FFTHelper: self._transform_ifft,
        ChebychevHelper: self._transform_idct,
        UltrasphericalHelper: self._transform_idct,
    }

    # default: all axes, as negative indices (-ndim, ..., -1)
    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    # default: no padding along any axis
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by the exact type of their 1D base and drop empty groups
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = list(self.global_shape[1:])

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
        )
        if self.comm is not None:
            _in /= np.prod([self.axes[i].N for i in _axes])

        alignment = self.ndim + _axes[0]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # record the shape after this group's transform, taken from the PFFT object when there is one
        # NOTE(review): `_input_shape` is a private attribute of the mpi4py-fft object — confirm it stays available
        for _ax in _axes:
            if fft:
                shape[_ax] = fft._input_shape[_ax]
            else:
                shape[_ax] = _out.shape[_ax]

        # realign for the next group of axes (or keep the current alignment after the last group)
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
        )

    return result

1740 

def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
    """
    Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
    object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
    shape especially, if applicable.

    Without a communicator, with a single rank, or when `axis_in == axis_out`, a copy of `u` is returned.

    Args:
        u: The solution
        axis_in (int): Current alignment
        axis_out (int): New alignment
        fft (mpi4py_fft.PFFT), optional: parallel FFT object, only used in the fallback route via `newDistArray`
        forward (bool): Whether the input is in spectral space or not

    Returns:
        solution aligned on `axis_out`
    """
    if self.comm is None or axis_in == axis_out:
        return u.copy()
    if self.comm.size == 1:
        return u.copy()

    global_fft = self.get_fft(**kwargs)
    axisA = [me.axisA for me in global_fft.transfer]
    axisB = [me.axisB for me in global_fft.transfer]

    current_axis = axis_in

    if axis_in in axisA and axis_out in axisB:
        # step forward through the transfer objects of `global_fft` until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisA.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayA[:] = u[:]

            transfer.forward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisB
            u = arrayB

        return u
    elif axis_in in axisB and axis_out in axisA:
        # same as above, stepping backward through the transfer objects
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisB.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayB[:] = u[:]

            transfer.backward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisA
            u = arrayA

        return u
    else:  # go the potentially slower route of not reusing transfer classes
        from mpi4py_fft import newDistArray

        fft = self.get_fft(**kwargs) if fft is None else fft

        _in = newDistArray(fft, forward).redistribute(axis_in)
        _in[...] = u

        return _in.redistribute(axis_out)

1805 

def itransform(self, u, axes=None, padding=None):
    """
    Inverse transform over all components of the solution, one component at a time.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform (default: all axes, as negative indices)
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before before transforming

    Returns:
        Transformed data, with the components stacked along the first dimension
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        transformed[idx] = self.itransform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(transformed)

1837 

def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
    which you can sort out via the `SpectralHelper.local_slice` attribute. When constructing a 1D matrix, you can
    use this method to get the part corresponding to the modes carried by this rank.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions. This means slab decomposition only.

    Returns:
        sparse local matrix (CSC), restricted to the local rows and columns
    """
    local = self.local_slice[axis]
    return M.tocsc()[local, local]

1852 

def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
    admissible. In more than one dimension, the filter is combined with identities in all other directions.

    Args:
        axis (int): Axis along which to filter

    Returns:
        sparse bandpass matrix
    """
    from functools import reduce

    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    mats = [base.get_Id() for base in self.axes]
    mats[axis] = self.axes[axis].get_filter_matrix(**kwargs)
    # Kronecker product over all axes with the first axis outermost. Reducing pairwise gives the same result as
    # `kron(A, B)` in 2D and also works in more dimensions, where `kron(*mats)` would wrongly pass the third matrix
    # as the `format` argument.
    return reduce(self.sparse_lib.kron, mats)

1867 

def get_differentiation_matrix(self, axes, **kwargs):
    """
    Get differentiation matrix along the specified axes. `kwargs` are forwarded to the 1D base implementation.

    In 2D, the 1D matrix of each entry in `axes` is expanded to the full (local) space via a Kronecker product with an
    identity in the other direction. The resulting matrices are multiplied in the order given, so repeating an axis,
    e.g. `axes=(0, 0)`, composes the 1D operator with itself.

    Args:
        axes (tuple): Axes along which to differentiate. Not used in 1D.

    Returns:
        sparse differentiation matrix

    Raises:
        ValueError: If `axes` is empty in 2D
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = self.ndim

    if ndim == 1:
        D = self.axes[0].get_differentiation_matrix(**kwargs)
    elif ndim == 2:
        if len(axes) == 0:
            raise ValueError('Need at least one axis to construct a differentiation matrix!')

        for i, axis in enumerate(axes):
            axis2 = (axis + 1) % ndim
            D1D = self.axes[axis].get_differentiation_matrix(**kwargs)

            if len(axes) > 1:
                I1D = sp.eye(self.axes[axis2].N)
            else:
                I1D = self.axes[axis2].get_Id()

            mats = [None] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

            # Decide by loop position rather than by axis value so that repeated axes get composed instead of
            # restarting the product.
            if i == 0:
                D = sp.kron(*mats)
            else:
                D = D @ sp.kron(*mats)
    else:
        raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!')

    return D

1905 

def get_integration_matrix(self, axes):
    """
    Get integration matrix to integrate along the specified axes.

    In 2D, the 1D matrix of each entry in `axes` is expanded to the full (local) space via a Kronecker product with an
    identity in the other direction. The resulting matrices are multiplied in the order given, so repeating an axis,
    e.g. `axes=(0, 0)`, composes the 1D operator with itself.

    Args:
        axes (tuple): Axes along which to integrate. Not used in 1D.

    Returns:
        sparse integration matrix

    Raises:
        ValueError: If `axes` is empty in 2D
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        S = self.axes[0].get_integration_matrix()
    elif ndim == 2:
        if len(axes) == 0:
            raise ValueError('Need at least one axis to construct an integration matrix!')

        for i, axis in enumerate(axes):
            axis2 = (axis + 1) % ndim
            S1D = self.axes[axis].get_integration_matrix()

            if len(axes) > 1:
                I1D = sp.eye(self.axes[axis2].N)
            else:
                I1D = self.axes[axis2].get_Id()

            mats = [None] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

            # Decide by loop position rather than by axis value so that repeated axes get composed instead of
            # restarting the product.
            if i == 0:
                S = sp.kron(*mats)
            else:
                S = S @ sp.kron(*mats)
    else:
        raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')

    return S

1943 

def get_Id(self):
    """
    Get identity matrix.

    In 1D, this is the identity of the single base. In 2D, it is assembled from the bases' identities via Kronecker
    products, restricted to the local modes and returned as a complex matrix.

    Returns:
        sparse identity matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = self.ndim

    if ndim == 1:
        return self.axes[0].get_Id()

    if ndim == 2:
        # Seeding the product with a complex identity makes the result complex. Its size is taken from
        # `self.init[0][1:]`, presumably the local spatial shape without the component dimension -- confirm.
        # It is only built here, because the other branches would discard it.
        I = sp.eye(np.prod(self.init[0][1:]), dtype=complex)
        for axis in range(ndim):
            axis2 = (axis + 1) % ndim
            I1D = self.axes[axis].get_Id()

            I1D2 = sp.eye(self.axes[axis2].N)

            mats = [None] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D2, axis2)

            I = I @ sp.kron(*mats)
        return I

    raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!')

1973 

def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get Dirichlet recombination matrix along `axis`. Note that it only makes sense in directions discretized with
    variations of Chebychev bases.

    In 2D, the 1D recombination matrix is combined with the identity of the other base via a Kronecker product,
    restricted to the local modes.

    Args:
        axis (int): Axis you discretized with Chebychev. Not used in 1D.

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: For more than two dimensions
    """
    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        C = self.axes[0].get_Dirichlet_recombination_matrix()
    elif ndim == 2:
        axis2 = (axis + 1) % ndim
        C1D = self.axes[axis].get_Dirichlet_recombination_matrix()

        I1D = self.axes[axis2].get_Id()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
        mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

        C = sp.kron(*mats)
    else:
        raise NotImplementedError(f'Dirichlet recombination matrix not implemented for {ndim} dimension!')

    return C

2004 

def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Some spectral bases do a change between bases while differentiating. This method returns matrices that change the
    basis to whatever you want. Refer to the methods of the same name of the 1D bases to learn which parameters you
    need to pass here as `kwargs`.

    In 2D, the 1D matrix of each entry in `axes` is expanded to the full (local) space via a Kronecker product with an
    identity in the other direction. The resulting matrices are multiplied in the order given, so repeating an axis
    composes the 1D operator with itself.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes, given as negative indices. Not used in 1D.

    Returns:
        sparse basis change matrix

    Raises:
        ValueError: If `axes` is empty in 2D
        NotImplementedError: For more than two dimensions
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes

    sp = self.sparse_lib
    ndim = len(self.axes)

    if ndim == 1:
        C = self.axes[0].get_basis_change_matrix(**kwargs)
    elif ndim == 2:
        if len(axes) == 0:
            raise ValueError('Need at least one axis to construct a basis change matrix!')

        for i, axis in enumerate(axes):
            axis2 = (axis + 1) % ndim
            C1D = self.axes[axis].get_basis_change_matrix(**kwargs)

            if len(axes) > 1:
                I1D = sp.eye(self.axes[axis2].N)
            else:
                I1D = self.axes[axis2].get_Id()

            mats = [None] * ndim
            mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)

            # Decide by loop position rather than by axis value so that repeated axes get composed instead of
            # restarting the product.
            if i == 0:
                C = sp.kron(*mats)
            else:
                C = C @ sp.kron(*mats)
    else:
        raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')

    return C