Coverage for pySDC/helpers/spectral_helper.py: 92%

778 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 13:12 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5 

6 

class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain, or None if either boundary is not given
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np

    def __init__(self, N, x0=None, x1=None, useGPU=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
        """
        self.N = N
        self.x0 = x0
        self.x1 = x1
        # The boundaries default to None. Leave the length undefined in that case instead of raising a TypeError.
        self.L = None if x0 is None or x1 is None else x1 - x0
        self.useGPU = useGPU

        if useGPU:
            self.setup_GPU()

    @classmethod
    def setup_GPU(cls):
        """
        Switch to GPU modules from CuPy.
        Note that this replaces class attributes, i.e. it affects every instance of `cls`.
        """
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib

        # NOTE(review): imported but not used here; kept so that a missing GPU datatype fails early
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretizations change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement. Please refer to the implementations of this method in the
                        individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter that keeps the modes with kmin <= |k| < kmax and removes all others.

        Args:
            kmin (int): Lower limit of the bandpass filter (inclusive)
            kmax (int): Upper limit of the bandpass filter (exclusive). Defaults to the largest |k|, which means that
                        the mode(s) with the largest |k| are removed by default.

        Returns:
            sparse matrix
        """

        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        # The filter is assembled in LIL format, which only exists on the CPU, so GPU data is moved to the host first
        if self.useGPU:
            Id = self.get_Id().get()
            mask = mask.get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        F = F.tocsc()

        # hand the result back in the same sparse library as all other matrices of this basis
        return self.sparse_lib.csc_matrix(F) if self.useGPU else F

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

172 

173 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
    """

    def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
                                             use the FFT from the library to compute the DCT. Case-insensitive.
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implemented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implemented
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)
        self.transform_type = transform_type

        # `transform` compares the transform type case-insensitively, so the FFT utilities must be set up accordingly
        if self.transform_type.lower() == 'fft':
            self.get_fft_utils()

        self.cache = {}
        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices of the default size are cached. So feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = N if N else self.N

        # The cache is keyed by name only, so only default-size matrices may go in there. Otherwise a request with a
        # custom `N` would replace the cached matrix and later default-size requests would get the wrong size.
        use_cache = N == self.N
        if use_cache and name in self.cache:
            return self.cache[name]

        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)) / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # try the inverse of the conversion in the opposite direction instead
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError from E

            import scipy.sparse as scipy_sparse

            if self.sparse_lib == scipy_sparse:
                mat = self.sparse_lib.linalg.inv(fwd.tocsc())
            else:
                # sparse inversion is done on the CPU and the result is moved back to the GPU
                mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))

        if use_cache:
            self.cache[name] = mat
        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        n = xp.arange(self.N)

        # Integrate by converting T to U, where the antiderivative of U_n is T_{n+1} / (n + 1)
        S = self.sparse_lib.diags(1 / (xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        S = S.tocsc()

        # The first row holds integration constants, which are only needed for odd modes, such that the antiderivative
        # vanishes at x=0 on the default interval [-1, 1]. `xp.maximum(m, 1)` is [1, 1, 2, ..., N // 2 - 1] and is
        # built with `xp`, so it lives on the same device as the other arrays.
        m = xp.arange(self.N // 2)
        S[0, 1::2] = (n / (2 * (n + 1)))[1::2] * (-1) ** m / xp.maximum(m, 1)

        # The differentiation matrix divides by lin_trf_fac, so the full integral (not only the constants) scales with it
        return S * self.lin_trf_fac

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            Sparse (CSC) differentiation matrix with dense fill-in
        '''
        xp = self.xp
        j = xp.arange(self.N)

        # D[k, j] = 2 * j * ((j - k) % 2) for k < j and zero otherwise. It is built vectorized rather than element by
        # element, which is very slow for GPU arrays.
        D = xp.triu(2.0 * j[xp.newaxis, :] * ((j[xp.newaxis, :] - j[:, xp.newaxis]) % 2), k=1)

        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def get_fft_shuffle(self, forward, N):
        """
        In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
        doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
        in the complex plane. This function returns a mask to do the shuffling.

        Args:
            forward (bool): Whether you want the shuffle for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Use as mask
        """
        xp = self.xp
        if forward:
            return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
        else:
            mask = xp.zeros(N, dtype=int)
            mask[: N - N % 2 : 2] = xp.arange(N // 2)
            mask[1::2] = N - xp.arange(N // 2) - 1
            mask[-1] = N // 2
            return mask

    def get_fft_shift(self, forward, N):
        """
        As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.

        Args:
            forward (bool): Whether you want the rotation for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Rotation
        """
        xp = self.xp
        # use the requested size rather than self.N so that the result matches `N`
        k = xp.arange(N)
        norm = self.get_norm(N)
        if forward:
            return 2 * xp.exp(-1j * np.pi * k / (2 * N)) * norm
        else:
            shift = xp.exp(1j * np.pi * k / (2 * N))
            shift[0] = 0.5
            return shift / norm

    def get_fft_utils(self):
        """
        Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
        them cached.
        """
        self.fft_utils = {
            'fwd': {},
            'bck': {},
        }

        # forwards transform
        self.fft_utils['fwd']['shuffle'] = self.get_fft_shuffle(True, self.N)
        self.fft_utils['fwd']['shift'] = self.get_fft_shift(True, self.N)

        # backwards transform
        self.fft_utils['bck']['shuffle'] = self.get_fft_shuffle(False, self.N)
        self.fft_utils['bck']['shift'] = self.get_fft_shift(False, self.N)

        return self.fft_utils

    def _get_expansion(self, u, axis):
        """
        Index tuple that turns a 1D array into one that broadcasts against `u` along `axis`.
        """
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(None)
        return (*expansion,)

    def transform(self, u, axis=-1, **kwargs):
        """
        1D DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in spectral space
        """
        transform_type = self.transform_type.lower()
        expansion = self._get_expansion(u, axis)

        if transform_type == 'dct':
            # the normalization belongs to the transformed axis, not necessarily to the last one
            return self.fft_lib.dct(u, axis=axis, **kwargs) * self.norm[expansion]
        elif transform_type == 'fft':
            result = u.copy()

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['fwd']['shuffle']

            v = u[(*shuffle,)]

            V = self.fft_lib.fft(v, axis=axis, **kwargs)

            V *= self.fft_utils['fwd']['shift'][expansion]

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def itransform(self, u, axis=-1, **kwargs):
        """
        1D inverse DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in physical space
        """
        assert self.norm.shape[0] == u.shape[axis]

        transform_type = self.transform_type.lower()
        expansion = self._get_expansion(u, axis)

        if transform_type == 'dct':
            # the normalization belongs to the transformed axis, not necessarily to the last one
            return self.fft_lib.idct(u / self.norm[expansion], axis=axis, **kwargs)
        elif transform_type == 'fft':
            result = u.copy()

            v = self.fft_lib.ifft(u * self.fft_utils['bck']['shift'][expansion], axis=axis, **kwargs)

            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['bck']['shuffle']
            V = v[(*shuffle,)]

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        # integral of T_i over [-1, 1] is ((-1)^i + 1) / (1 - i^2) for i >= 2; here n[i - 1] == i
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_n(0) is 0 for odd n and alternates between +1 and -1 for even n
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)

542 

543 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # assembled with SciPy in LIL format to allow modifying the first column
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base
        """
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        elif p_out == p_in:
            # Nothing to convert: return the identity directly instead of inverting it (on CPU)
            return mat_fwd.tocsc()

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp

        # select all modes but the 0-th along `axis` and keep every other axis in full
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, u_hat.shape[axis])

        # T_n(-1) = (-1)^n: the signs have to be broadcast along `axis`, not along the last axis
        expansion = [np.newaxis] * u_hat.ndim
        expansion[axis] = slice(None)
        signs = ((-1) ** xp.arange(u_hat.shape[axis] - 1))[(*expansion,)]

        return xp.sum(u_hat[(*slices,)] * signs, axis=axis)

673 

674 

class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems. Spectral space is ordered like the output of the FFT library and the
    differentiation matrix is diagonal.
    """

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary and not including the right one, which, due to
        periodicity, coincides with the left boundary.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive: it follows the FFT library's `fftfreq` ordering.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()

        if self.useGPU:
            # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            D = self.sparse_lib.diags(1j * k).get()
            return self.sparse_lib.csc_matrix(sp.linalg.matrix_power(D, p))
        else:
            return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p)

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L
        D_inv = self.sparse_lib.diags(1 / (1j * k))

        if self.useGPU:
            # Same as for differentiation: raise to power p on CPU because CuPy lacks `matrix_power`.
            import scipy.sparse as sp

            return self.sparse_lib.csc_matrix(sp.linalg.matrix_power(D_inv.get(), p))
        else:
            return self.linalg.matrix_power(D_inv, p)

    def transform(self, u, axis=-1, **kwargs):
        """
        1D FFT along axis. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.fft(u, axis=axis, **kwargs)

    def itransform(self, u, axis=-1, **kwargs):
        """
        Inverse 1D FFT. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.ifft(u, axis=axis, **kwargs)

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        # use the array's own reduction rather than the builtin, which iterates element-wise (slow on GPU)
        Nyquist_mode = k.min()
        return self.xp.where(k == Nyquist_mode)[0][0]

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me

809 

810 

811class SpectralHelper: 

812 """ 

813 This class has three functions: 

814 - Easily assemble matrices containing multiple equations 

815 - Direct product of 1D bases to solve problems in more dimensions 

816 - Distribute the FFTs to facilitate concurrency. 

817 

818 Attributes: 

819 comm (mpi4py.Intracomm): MPI communicator 

820 debug (bool): Perform additional checks at extra computational cost 

821 useGPU (bool): Whether to use GPUs 

822 axes (list): List of 1D bases 

823 components (list): List of strings of the names of components in the equations 

824 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

825 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

826 BCs (sparse matrix): Matrix containing only the BCs 

827 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

828 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

829 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs 

830 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

831 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

832 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

833 local_slice (slice): Local slice of the solution as in `mpi4py-fft` 

834 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

835 init (tuple): This is the same `init` that is used throughout the problem classes 

836 init_forward (tuple): This is the equivalent of `init` in spectral space 

837 """ 

838 

839 xp = np 

840 fft_lib = scipy.fft 

841 sparse_lib = scipy.sparse 

842 linalg = scipy.sparse.linalg 

843 dtype = mesh 

844 fft_backend = 'fftw' 

845 fft_comm_backend = 'MPI' 

846 

@classmethod
def setup_GPU(cls):
    """
    Switch to GPU modules from CuPy.

    This replaces the class attributes for the array module, the sparse library, sparse linear algebra, the FFT
    library, the FFT backends and the datatype. The FFT library is switched as well, consistent with
    `SpectralHelper1D.setup_GPU`.
    """
    import cupy as cp
    import cupyx.scipy.sparse as sparse_lib
    import cupyx.scipy.sparse.linalg as linalg
    import cupyx.scipy.fft as fft_lib
    from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

    cls.xp = cp
    cls.sparse_lib = sparse_lib
    cls.linalg = linalg
    cls.fft_lib = fft_lib

    cls.fft_backend = 'cupy'
    cls.fft_comm_backend = 'NCCL'

    cls.dtype = cupy_mesh

863 

def __init__(self, comm=None, useGPU=False, debug=False):
    """
    Set up an empty spectral helper without any axes or components.

    Args:
        comm (mpi4py.Intracomm): MPI communicator
        useGPU (bool): Whether to use GPUs
        debug (bool): Perform additional checks at extra computational cost
    """
    self.comm, self.debug, self.useGPU = comm, debug, useGPU

    if useGPU:
        self.setup_GPU()

    # 1D bases and names of the solution components
    self.axes, self.components = [], []

    # boundary condition bookkeeping, starts out empty
    self.full_BCs = []
    self.BC_mat = self.BCs = None

    # caches for FFT objects and dealiasing shapes
    self.fft_cache, self.fft_dealias_shape_cache = {}, {}

889 

@property
def u_init(self):
    """
    Empty data container in physical space, constructed by calling `self.dtype` on `self.init`.
    """
    container_type = self.dtype
    return container_type(self.init)

896 

@property
def u_init_forward(self):
    """
    Empty data container in spectral space, constructed by calling `self.dtype` on `self.init_forward`.
    """
    container_type = self.dtype
    return container_type(self.init_forward)

903 

@property
def shape(self):
    """
    Shape of an individual solution component: `self.init[0]` without its first entry (presumably the number of
    components).
    """
    full_shape = self.init[0]
    return full_shape[1:]

910 

@property
def ndim(self):
    """Number of 1D bases in `self.axes`, i.e. the number of dimensions of the domain."""
    return len(self.axes)

914 

@property
def ncomponents(self):
    """Number of solution components in `self.components`."""
    return len(self.components)

918 

@property
def V(self):
    """
    Volume of the domain, computed as the product of the lengths `L` of all 1D bases.
    """
    lengths = [axis.L for axis in self.axes]
    return np.prod(lengths)

925 

def add_axis(self, base, *args, **kwargs):
    """
    Add an axis to the domain by deciding on suitable 1D base.
    Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
    bases for possible arguments.

    Args:
        base (str): 1D spectral method (case-insensitive). The lowercased class names `chebychevhelper`,
                    `ffthelper` and `ultrasphericalhelper` are accepted as aliases.

    Raises:
        NotImplementedError: If `base` is not a known 1D spectral method
    """
    kwargs['useGPU'] = self.useGPU

    base_name = base.lower()
    if base_name in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper', 'chebychevhelper']:
        kwargs['transform_type'] = kwargs.get('transform_type', 'fft')
        self.axes.append(ChebychevHelper(*args, **kwargs))
    elif base_name in ['fft', 'fourier', 'ffthelper']:
        self.axes.append(FFTHelper(*args, **kwargs))
    elif base_name in ['ultraspherical', 'gegenbauer', 'ultrasphericalhelper']:
        self.axes.append(UltrasphericalHelper(*args, **kwargs))
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    # share the array and sparse modules of this helper with the new 1D basis
    self.axes[-1].xp = self.xp
    self.axes[-1].sparse_lib = self.sparse_lib

948 

def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list/tuple of strings): Name(s) of component(s). Lists and tuples are added element by element.

    Raises:
        ValueError: If a component with this name was already added
        NotImplementedError: If `name` is neither a string nor a list/tuple
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            # ValueError subclasses Exception, so callers catching the old generic Exception still work
            raise ValueError(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError(f'Cannot add component of {type(name)=}, expected str or list/tuple of str')

965 

def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str or list of strings): Name(s) of component(s)

    Returns:
        int: Index of the component, or a tuple of indices if a list/tuple of names was given

    Raises:
        ValueError: If a name is not a known component
        NotImplementedError: If `name` has an unsupported type
    """
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # a tuple instead of a one-shot generator, so the result can be iterated more than once
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

982 

def get_empty_operator_matrix(self):
    """
    Return a matrix of operators to be filled with the connections between the solution components.

    Returns:
        list of lists (one row per component) whose entries all refer to the same sparse zero matrix of the shape of
        `self.get_Id()`
    """
    zero = self.get_Id() * 0
    n_components = len(self.components)
    # each row is its own list, but all entries share the same zero matrix object
    return [[zero] * n_components for _ in range(n_components)]

993 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): If True, put the BC only at index 0 of the other axis; otherwise put it at all positions of
            the other axis

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than two dimensions
    """
    ndim = len(self.axes)
    if ndim > 2:
        # previously this fell through and silently returned None
        raise NotImplementedError(f'Boundary conditions are not implemented for {ndim} dimensions!')

    sp = scipy.sparse

    base = self.axes[axis]

    # empty N x N matrix where only row `line` is filled with the 1D boundary condition
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # the matrix is assembled with scipy on the host, so the 1D BC is copied off the GPU
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)

    # two dimensions: combine the 1D BC with a (possibly partial) identity in the other direction
    axis2 = (axis + 1) % ndim

    if scalar:
        # identity restricted to index 0 of the other axis
        _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
    else:
        _Id = self.axes[axis2].get_Id()

    Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

    if self.useGPU:
        Id = Id.get()

    mats = [None] * ndim
    mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
    mats[axis2] = Id
    return self.sparse_lib.csc_matrix(sp.kron(*mats))

1042 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    The matrix from `get_BC` is subtracted from the BC operator. On the task that holds `line`, the matching entries
    of the rhs mask are cleared.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in. Negative values count from the end of the axis.
        scalar (bool): Whether the BC is a scalar one, i.e. sits only at index 0 of the other axes
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    N = self.axes[axis].N
    global_line = (N + line) % N  # non-negative position of the line along `axis`
    local_slice = self.local_slice[axis]

    if global_line not in self.xp.arange(N)[local_slice]:
        # the line is held by another task, so there is nothing to clear here
        return

    if scalar:
        # scalar BCs sit at index 0 of all other axes, as in `add_BC`
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
    else:
        # translate the global line into an index of the local data, consistent with `add_BC`
        slices = (
            [self.index(equation)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis)]
            + [global_line - local_slice.start]
            + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
        )
    self.BC_rhs_mask[(*slices,)] = False

1076 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Besides adding to the BC operator, the BC is recorded in `self.full_BCs`, and the rows it replaces are flagged in
    `self.BC_rhs_mask`.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in. Negative values count from the end of the axis.
        scalar (bool): If True, the BC is put only at index 0 of the other axes (and only on rank 0 with MPI);
            otherwise it is put at all positions of the other axes
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # a scalar BC appears only once globally: on rank 0 when running with MPI
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        N = self.axes[axis].N
        global_line = (N + line) % N  # non-negative position of the line along `axis`
        local_slice = self.local_slice[axis]
        if global_line in self.xp.arange(N)[local_slice]:
            # convert to a local index; subtracting the offset from a negative `line` would give a wrong index
            slices = (
                [self.index(equation)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [global_line - local_slice.start]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            self.BC_rhs_mask[(*slices,)] = True

1128 

def setup_BCs(self):
    """
    Convert the list of lists of BCs to the boundary condition operator.
    Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
    condition. This method sets up a suitable sparse matrix to do this.
    Finally, the BC values are put into an empty rhs and transformed, so they can be added in spectral space later.
    """
    sparse = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    # flat indices of all rows that are replaced by boundary conditions
    n_dof = np.prod(self.init[0])
    self.BC_zero_index = self.xp.arange(n_dof)[self.BC_rhs_mask.flatten()]

    # diagonal with ones everywhere except in the BC rows
    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sparse.diags(keep_rows)

    # prepare BCs in spectral space to easily add to the RHS
    self.rhs_BCs_hat = self.transform(self.put_BCs_in_rhs(self.u_init))

1146 

def check_BCs(self, u):
    """
    Check that the solution satisfies the non-scalar boundary conditions. Only up to two dimensions are supported.

    For every axis that carries BCs, `u` is transformed along that axis and the BC row from the 1D base is applied to
    the respective component. The result is compared to the prescribed value with `allclose`.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a boundary condition is violated
    """
    assert self.ndim < 3
    # entries of the stored BC dicts that are not arguments for the 1D `get_BC`
    not_bc_args = ('component', 'equation', 'axis', 'v', 'line', 'scalar')

    for axis in range(self.ndim):
        relevant = [bc for bc in self.full_BCs if bc["axis"] == axis and not bc["scalar"]]
        if not relevant:
            continue

        u_hat = self.transform(u, axes=(axis - self.ndim,))
        for bc in relevant:
            bc_kwargs = {key: value for key, value in bc.items() if key not in not_bc_args}
            bc_row = self.axes[axis].get_BC(**bc_kwargs)
            comp_hat = u_hat[self.index(bc['component'])]

            # axis 0 acts on rows, axis 1 on columns
            get = bc_row @ comp_hat if axis == 0 else comp_hat @ bc_row
            want = bc['v']
            assert self.xp.allclose(
                get, want
            ), f'Unexpected BC in {bc["component"]} in equation {bc["equation"]}, line {bc["line"]}! Got {get}, wanted {want}'

1175 

def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions in a matrix by replacing rows with BCs.

    The rows of `A` that hold BCs are zeroed with `self.BC_line_zero_matrix` before the BC operator is added.
    """
    zeroed = self.BC_line_zero_matrix @ A
    return zeroed + self.BCs

1181 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.

    Args:
        rhs_hat: Right hand side in spectral space. Entries in BC lines are set to zero in place.

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate a mask where we need to set values in the rhs in spectral space to zero, such that we can replace
        # them by the boundary conditions. The mask is then cached.
        # NOTE(review): BCs added after the first call are not reflected in the cached mask.
        self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if axis != bc['axis']:
                    continue

                N = self.axes[axis].N
                global_line = (N + bc['line']) % N  # non-negative position of the line along `axis`
                local_slice = self.local_slice[axis]
                if global_line in self.xp.arange(N)[local_slice]:
                    # index into the local data; subtracting the offset from a negative line would be wrong
                    _slice = (
                        [self.index(bc['equation'])]
                        + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                        + [global_line - local_slice.start]
                        + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                    )
                    self._rhs_hat_zero_mask[(*_slice,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1216 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space, with the component index first (must not be flattened)

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        for bc in self.full_BCs:
            if axis != bc['axis']:
                continue

            N = self.axes[axis].N
            global_line = (N + bc['line']) % N  # non-negative position of the line along `axis`
            local_slice = self.local_slice[axis]
            if global_line in self.xp.arange(N)[local_slice]:
                # index into the local data; subtracting the offset from a negative line would be wrong
                _slice = (
                    [self.index(bc['equation'])]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                    + [global_line - local_slice.start]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                )
                _rhs_hat[(*_slice,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs

1254 

def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later. Existing entries for the same (equation, component) are overwritten.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled, modified in place
        equation (str): The equation of the component you want this in
        relations: (dict): Maps component names to the operator acting on that component in this equation
    """
    for component, operator in relations.items():
        row, col = self.index(equation), self.index(component)
        A[row][col] = operator

1293 

def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator to be converted

    Returns:
        sparse linear operator: the single block itself for one component, otherwise a CSC block matrix
    """
    if len(self.components) != 1:
        return self.sparse_lib.bmat(M, format='csc')
    return M[0][0]

1309 

def get_wavenumbers(self):
    """
    Get grid in spectral space: the local part of each axis' wavenumbers, combined with `meshgrid` in reversed axis
    order.
    """
    grids = [base.get_wavenumbers()[local] for base, local in zip(self.axes, self.local_slice)]
    grids.reverse()
    return self.xp.meshgrid(*grids)

1316 

def get_grid(self):
    """
    Get grid in physical space: the local part of each axis' 1D grid, combined with `meshgrid` in reversed axis
    order.
    """
    grids = [base.get_1dgrid()[local] for base, local in zip(self.axes, self.local_slice)]
    grids.reverse()
    return self.xp.meshgrid(*grids)

1323 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    Get a transform for the given axes, direction, padding and shape. Results are cached in `self.fft_cache`.
    When using MPI, we use `PFFT` objects generated by mpi4py-fft. Without MPI, the functions `self.xp.fft.fftn` /
    `self.xp.fft.ifftn` are returned, and the "object" is `None`.

    Args:
        axes (tuple): Axes you want to transform over. Defaults to all axes as negative indices. Must be hashable.
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing. Defaults to a factor of 1 (no padding) in every direction.
        shape (tuple): Shape of the transform. Defaults to `self.global_shape` without the component dimension.

    Returns:
        transform (function or PFFT object), or `None` for "object" without MPI
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache.keys():
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # forward/backward functions are bound methods of the (cached) PFFT object for the same settings
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    # NOTE(review): an unknown `direction` caches nothing, so this lookup raises a KeyError
    return self.fft_cache[key]

1384 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    If no components were added, a single component 'u' is added. With MPI, the local part of the data is taken
    from the PFFT object returned by `get_fft`.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
    self.local_slice = [slice(0, me.N) for me in self.axes]

    axes = tuple(range(len(self.axes)))
    self.fft_obj = self.get_fft(axes=axes, direction='object')
    if self.fft_obj is not None:
        self.local_slice = self.fft_obj.local_slice(False)

    # Local shape: all components and the local part of every axis. This is computed from the slices directly
    # rather than by slicing a throw-away array of the full global shape on every task.
    local_shape = (self.global_shape[0],) + tuple(
        len(range(*_slice.indices(N))) for _slice, N in zip(self.local_slice, self.global_shape[1:])
    )

    self.init = (
        local_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.xp.zeros(
        shape=self.init[0],
        dtype=bool,
    )

1430 

def _transform_fft(self, u, axes, **kwargs):
    """
    Forward FFT of `u` along `axes`, using the transform obtained from `get_fft` with direction 'forward'.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over
        **kwargs: Forwarded to `get_fft`, e.g. `padding` or `shape`

    Returns:
        transformed solution
    """
    # TODO: clean up and try putting more of this in the 1D bases
    return self.get_fft(axes, 'forward', **kwargs)(u, axes=axes)

1445 

def _transform_dct(self, u, axes, padding=None, **kwargs):
    '''
    DCT along `axes`.
    This will only return real values!
    When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
    wavenumbers in FFTs.

    The DCT is computed via an FFT: the data is reordered with the base's `get_fft_shuffle`, transformed, and
    multiplied by a shift factor from the base. Multiple axes are handled recursively, one axis at a time.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over
        padding (list): Padding factors per axis, or None for no padding
        **kwargs: Forwarded to `get_fft` in the unpadded case (e.g. `shape`)

    Returns:
        transformed solution
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    if len(axes) > 1:
        # NOTE(review): `padding` is not forwarded in this recursion, so it is ignored for multi-axis DCTs
        v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
    else:
        v = u.copy().astype(complex)
        axis = axes[0]
        base = self.axes[axis]

        # reorder the data along `axis` such that an FFT yields the DCT (up to the shift below)
        shuffle = [slice(0, s, 1) for s in u.shape]
        shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
        v = v[(*shuffle,)]

        if padding is not None:
            shape = list(v.shape)
            if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
                shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
            elif self.comm:
                # global size along the first axis = sum of the local sizes
                send_buf = np.array(v.shape[0])
                recv_buf = np.array(v.shape[0])
                self.comm.Allreduce(send_buf, recv_buf)
                shape[0] = int(recv_buf)
            fft = self.get_fft(axes, 'forward', shape=shape)
        else:
            fft = self.get_fft(axes, 'forward', **kwargs)

        v = fft(v, axes=axes)

        # index to broadcast the 1D shift along `axis`
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        if padding is not None:
            shift = base.get_fft_shift(True, v.shape[axis])

            if padding[axis] != 1:
                # keep only the lower 1 / padding part of the modes
                N = int(np.ceil(v.shape[axis] / padding[axis]))
                _expansion = [slice(0, n) for n in v.shape]
                _expansion[axis] = slice(0, N, 1)
                v = v[(*_expansion,)]
        else:
            shift = base.fft_utils['fwd']['shift']

        v *= shift[(*expansion,)]

    return v.real

1507 

def transform_single_component(self, u, axes=None, padding=None):
    """
    Transform a single component of the solution

    Axes are grouped by the type of their 1D base: Chebychev/ultraspherical axes use `_transform_dct` and Fourier
    axes use `_transform_fft`. With MPI, the data is realigned with `get_aligned` before and after each group, and
    the result is multiplied by the product of the axis sizes.

    Args:
        u: data to transform
        axes (tuple): Axes over which to transform. Defaults to all axes. The index arithmetic below assumes axes
            are given as negative indices, as in the default.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    trfs = {
        ChebychevHelper: self._transform_dct,
        UltrasphericalHelper: self._transform_dct,
        FFTHelper: self._transform_fft,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    # default: no padding in any direction
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by base type and drop base types without any axes
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]

    fft = self.get_fft(axes=axes, padding=padding, direction='object')
    if fft is not None:
        shape = list(fft.global_shape(False))

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        # empty groups were filtered out above, so this never triggers
        if len(_axes) == 0:
            continue

        # the transformed axes keep their unpadded global size
        for _ax in _axes:
            shape[_ax] = self.global_shape[1 + self.ndim + _ax]

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
        )

        alignment = self.ndim + _axes[-1]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        if self.comm is not None:
            _out *= np.prod([self.axes[i].N for i in _axes])

        # align for the next group of axes, if there is one
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
        )

    return result

1579 

def transform(self, u, axes=None, padding=None):
    """
    Transform all components from physical space to spectral space

    Args:
        u: data to transform, with the component index first
        axes (tuple): Axes over which to transform. Defaults to all axes as negative indices.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data, with the components stacked along the first axis
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [
        self.transform_single_component(u[i], axes=axes, padding=padding) for i in range(len(self.components))
    ]
    return self.xp.stack(transformed)

1611 

def _transform_ifft(self, u, axes, **kwargs):
    """
    Inverse FFT of `u` along `axes`, using the transform obtained from `get_fft` with direction 'backward'.
    `kwargs` are forwarded to `get_fft`.
    """
    # TODO: clean up and try putting more of this in the 1D bases
    return self.get_fft(axes, 'backward', **kwargs)(u, axes=axes)

1616 

def _transform_idct(self, u, axes, padding=None, **kwargs):
    '''
    Inverse DCT along `axes`, computed via an inverse FFT. The data is multiplied by a shift factor, transformed,
    and reordered with the base's `get_fft_shuffle`. With padding, zeros are appended along the axis before
    transforming. Multiple axes are handled recursively, one axis at a time.
    This will only ever return real values!

    Args:
        u: Data in spectral space
        axes (tuple): Axes you want to transform over
        padding (list): Padding factors per axis, or None for no padding
        **kwargs: Forwarded to `get_fft` (e.g. `shape`)

    Returns:
        transformed data
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    v = u.copy().astype(complex)

    if len(axes) > 1:
        # NOTE(review): neither `padding` nor `kwargs` are forwarded in this recursion, unlike `_transform_dct`
        v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
    else:
        axis = axes[0]
        base = self.axes[axis]

        if padding is not None:
            if padding[axis] != 1:
                # append zero modes up to the padded resolution
                N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
                _pad = [[0, 0] for _ in v.shape]
                _pad[axis] = [0, N_pad - base.N]
                v = self.xp.pad(v, _pad, 'constant')

                shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
            else:
                shift = base.fft_utils['bck']['shift']
        else:
            shift = base.fft_utils['bck']['shift']

        # index to broadcast the 1D shift along `axis`
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        v *= shift[(*expansion,)]

        if padding is not None:
            if padding[axis] != 1:
                shape = list(v.shape)
                if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
                    shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
                elif self.comm:
                    # global size along the first axis = sum of the local sizes
                    send_buf = np.array(v.shape[0])
                    recv_buf = np.array(v.shape[0])
                    self.comm.Allreduce(send_buf, recv_buf)
                    shape[0] = int(recv_buf)
                ifft = self.get_fft(axes, 'backward', shape=shape)
            else:
                ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        else:
            ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        v = ifft(v, axes=axes)

        # undo the reordering used in the forward DCT
        shuffle = [slice(0, s, 1) for s in v.shape]
        shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
        v = v[(*shuffle,)]

    return v.real

1673 

def itransform_single_component(self, u, axes=None, padding=None):
    """
    Inverse transform over single component of the solution

    Axes are grouped by the type of their 1D base: Fourier axes use `_transform_ifft` and Chebychev/ultraspherical
    axes use `_transform_idct`. With MPI, the data is realigned with `get_aligned` around each group and divided by
    the product of the axis sizes before transforming.

    Args:
        u: data to transform
        axes (tuple): Axes over which to transform. Defaults to all axes. The index arithmetic below assumes axes
            are given as negative indices, as in the default.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    trfs = {
        FFTHelper: self._transform_ifft,
        ChebychevHelper: self._transform_idct,
        UltrasphericalHelper: self._transform_idct,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by base type and drop base types without any axes
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = list(self.global_shape[1:])

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        # empty groups were filtered out above, so this never triggers
        if len(_axes) == 0:
            continue

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
        )
        if self.comm is not None:
            _in /= np.prod([self.axes[i].N for i in _axes])

        alignment = self.ndim + _axes[0]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # the transformed axes now have the (possibly padded) size of the output
        for _ax in _axes:
            if fft:
                shape[_ax] = fft._input_shape[_ax]
            else:
                shape[_ax] = _out.shape[_ax]

        # align for the next group of axes, if there is one
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
        )

    return result

1743 

def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
    """
    Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
    object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
    shape especially, if applicable.

    Without MPI, with a single task, or if `axis_in == axis_out`, a copy of `u` is returned.

    Args:
        u: The solution
        axis_in (int): Current alignment
        axis_out (int): New alignment
        fft (mpi4py_fft.PFFT), optional: parallel FFT object, only used in the fallback via `newDistArray`
        forward (bool): Passed to `newDistArray` in the fallback; presumably whether the data is in spectral space
            (TODO confirm against mpi4py-fft)

    Returns:
        solution aligned on `axis_out`
    """
    if self.comm is None or axis_in == axis_out:
        return u.copy()
    if self.comm.size == 1:
        return u.copy()

    global_fft = self.get_fft(**kwargs)
    axisA = [me.axisA for me in global_fft.transfer]
    axisB = [me.axisB for me in global_fft.transfer]

    current_axis = axis_in

    if axis_in in axisA and axis_out in axisB:
        # step forward through the PFFT's transfer objects until the data is aligned on `axis_out`
        # NOTE(review): a ValueError is raised if the chain of transfers does not contain `current_axis`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisA.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayA[:] = u[:]

            transfer.forward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisB
            u = arrayB

        return u
    elif axis_in in axisB and axis_out in axisA:
        # step backward through the PFFT's transfer objects until the data is aligned on `axis_out`
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisB.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayB[:] = u[:]

            transfer.backward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisA
            u = arrayA

        return u
    else:  # go the potentially slower route of not reusing transfer classes
        from mpi4py_fft import newDistArray

        fft = self.get_fft(**kwargs) if fft is None else fft

        _in = newDistArray(fft, forward).redistribute(axis_in)
        _in[...] = u

        return _in.redistribute(axis_out)

1808 

def itransform(self, u, axes=None, padding=None):
    """
    Inverse transform over all components of the solution

    Args:
        u: data to transform, with the component index first
        axes (tuple): Axes over which to transform. Defaults to all axes as negative indices.
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before transforming

    Returns:
        Transformed data, with the components stacked along the first axis
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    transformed = [
        self.itransform_single_component(u[i], axes=axes, padding=padding) for i in range(len(self.components))
    ]
    return self.xp.stack(transformed)

1840 

def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
    which you can sort out via the `SpectralHelper.local_slice` attribute. When constructing a 1D matrix, you can
    use this method to get the part corresponding to the modes carried by this rank.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions. This means slab decomposition only.

    Returns:
        sparse local matrix in CSC format, restricted to the local rows and columns
    """
    local = self.local_slice[axis]
    as_csc = M.tocsc()
    return as_csc[local, local]

1855 

1856 def get_filter_matrix(self, axis, **kwargs): 

1857 """ 

1858 Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are 

1859 admissible. 

1860 

1861 Returns: 

1862 sparse bandpass matrix 

1863 """ 

1864 if self.ndim == 1: 

1865 return self.axes[0].get_filter_matrix(**kwargs) 

1866 

1867 mats = [base.get_Id() for base in self.axes] 

1868 mats[axis] = self.axes[axis].get_filter_matrix(**kwargs) 

1869 return self.sparse_lib.kron(*mats) 

1870 

1871 def get_differentiation_matrix(self, axes, **kwargs): 

1872 """ 

1873 Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation. 

1874 

1875 Args: 

1876 axes (tuple): Axes along which to differentiate. 

1877 

1878 Returns: 

1879 sparse differentiation matrix 

1880 """ 

1881 sp = self.sparse_lib 

1882 ndim = self.ndim 

1883 

1884 if ndim == 1: 

1885 D = self.axes[0].get_differentiation_matrix(**kwargs) 

1886 elif ndim == 2: 

1887 for axis in axes: 

1888 axis2 = (axis + 1) % ndim 

1889 D1D = self.axes[axis].get_differentiation_matrix(**kwargs) 

1890 

1891 if len(axes) > 1: 

1892 I1D = sp.eye(self.axes[axis2].N) 

1893 else: 

1894 I1D = self.axes[axis2].get_Id() 

1895 

1896 mats = [None] * ndim 

1897 mats[axis] = self.get_local_slice_of_1D_matrix(D1D, axis) 

1898 mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) 

1899 

1900 if axis == axes[0]: 

1901 D = sp.kron(*mats) 

1902 else: 

1903 D = D @ sp.kron(*mats) 

1904 else: 

1905 raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!') 

1906 

1907 return D 

1908 

1909 def get_integration_matrix(self, axes): 

1910 """ 

1911 Get integration matrix to integrate along specified axis. 

1912 

1913 Args: 

1914 axes (tuple): Axes along which to integrate over. 

1915 

1916 Returns: 

1917 sparse integration matrix 

1918 """ 

1919 sp = self.sparse_lib 

1920 ndim = len(self.axes) 

1921 

1922 if ndim == 1: 

1923 S = self.axes[0].get_integration_matrix() 

1924 elif ndim == 2: 

1925 for axis in axes: 

1926 axis2 = (axis + 1) % ndim 

1927 S1D = self.axes[axis].get_integration_matrix() 

1928 

1929 if len(axes) > 1: 

1930 I1D = sp.eye(self.axes[axis2].N) 

1931 else: 

1932 I1D = self.axes[axis2].get_Id() 

1933 

1934 mats = [None] * ndim 

1935 mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis) 

1936 mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) 

1937 

1938 if axis == axes[0]: 

1939 S = sp.kron(*mats) 

1940 else: 

1941 S = S @ sp.kron(*mats) 

1942 else: 

1943 raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!') 

1944 

1945 return S 

1946 

1947 def get_Id(self): 

1948 """ 

1949 Get identity matrix 

1950 

1951 Returns: 

1952 sparse identity matrix 

1953 """ 

1954 sp = self.sparse_lib 

1955 ndim = self.ndim 

1956 I = sp.eye(np.prod(self.init[0][1:]), dtype=complex) 

1957 

1958 if ndim == 1: 

1959 I = self.axes[0].get_Id() 

1960 elif ndim == 2: 

1961 for axis in range(ndim): 

1962 axis2 = (axis + 1) % ndim 

1963 I1D = self.axes[axis].get_Id() 

1964 

1965 I1D2 = sp.eye(self.axes[axis2].N) 

1966 

1967 mats = [None] * ndim 

1968 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) 

1969 mats[axis2] = self.get_local_slice_of_1D_matrix(I1D2, axis2) 

1970 

1971 I = I @ sp.kron(*mats) 

1972 else: 

1973 raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!') 

1974 

1975 return I 

1976 

1977 def get_Dirichlet_recombination_matrix(self, axis=-1): 

1978 """ 

1979 Get Dirichlet recombination matrix along axis. Not that it only makes sense in directions discretized with variations of Chebychev bases. 

1980 

1981 Args: 

1982 axis (int): Axis you discretized with Chebychev 

1983 

1984 Returns: 

1985 sparse matrix 

1986 """ 

1987 sp = self.sparse_lib 

1988 ndim = len(self.axes) 

1989 

1990 if ndim == 1: 

1991 C = self.axes[0].get_Dirichlet_recombination_matrix() 

1992 elif ndim == 2: 

1993 axis2 = (axis + 1) % ndim 

1994 C1D = self.axes[axis].get_Dirichlet_recombination_matrix() 

1995 

1996 I1D = self.axes[axis2].get_Id() 

1997 

1998 mats = [None] * ndim 

1999 mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis) 

2000 mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) 

2001 

2002 C = sp.kron(*mats) 

2003 else: 

2004 raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!') 

2005 

2006 return C 

2007 

2008 def get_basis_change_matrix(self, axes=None, **kwargs): 

2009 """ 

2010 Some spectral bases do a change between bases while differentiating. This method returns matrices that changes the basis to whatever you want. 

2011 Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`. 

2012 

2013 Args: 

2014 axes (tuple): Axes along which to change basis. 

2015 

2016 Returns: 

2017 sparse basis change matrix 

2018 """ 

2019 axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes 

2020 

2021 sp = self.sparse_lib 

2022 ndim = len(self.axes) 

2023 

2024 if ndim == 1: 

2025 C = self.axes[0].get_basis_change_matrix(**kwargs) 

2026 elif ndim == 2: 

2027 for axis in axes: 

2028 axis2 = (axis + 1) % ndim 

2029 C1D = self.axes[axis].get_basis_change_matrix(**kwargs) 

2030 

2031 if len(axes) > 1: 

2032 I1D = sp.eye(self.axes[axis2].N) 

2033 else: 

2034 I1D = self.axes[axis2].get_Id() 

2035 

2036 mats = [None] * ndim 

2037 mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis) 

2038 mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) 

2039 

2040 if axis == axes[0]: 

2041 C = sp.kron(*mats) 

2042 else: 

2043 C = C @ sp.kron(*mats) 

2044 else: 

2045 raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!') 

2046 

2047 return C