Coverage for pySDC/helpers/spectral_helper.py: 93%

743 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 07:24 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5from functools import wraps 

6 

7 

def cache(func):
    """
    Decorator for caching return values of methods.
    This is very similar to `functools.cache`, but without the memory leaks (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/): the cache is stored as an attribute on the instance
    (named `_<method name>_cache`) rather than in a global dictionary keyed by `self`.

    Example:

    .. code-block:: python

        class Example:
            num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        example = Example()
        example.increment(0)  # returns 1, num_calls = 1
        example.increment(1)  # returns 2, num_calls = 2
        example.increment(0)  # returns 1, num_calls = 2

    Positional and keyword arguments form the cache key. Calls with arguments that cannot be hashed are not cached,
    the method is simply evaluated every time.

    Args:
        func (function): The method you want to cache the return value of. Its first argument must be `self`.

    Returns:
        function: Wrapped method that returns cached values where possible
    """
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            key = (args, frozenset(kwargs.items()))
            hash(key)
        except TypeError:
            # Unhashable arguments cannot serve as a dictionary key, so compute the result without caching.
            return func(self, *args, **kwargs)

        if not hasattr(self, attr_cache):
            setattr(self, attr_cache, {})
        results = getattr(self, attr_cache)

        if key not in results:
            results[key] = func(self, *args, **kwargs)
        return results[key]

    return wrapper

53 

54 

class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain, or None if either boundary is not given
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np

    def __init__(self, N, x0=None, x1=None, useGPU=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
        """
        self.N = N
        self.x0 = x0
        self.x1 = x1
        # The length is only defined if both boundaries are known; `None - None` would raise a TypeError
        self.L = None if x0 is None or x1 is None else x1 - x0
        self.useGPU = useGPU

        if useGPU:
            self.setup_GPU()

    @classmethod
    def setup_GPU(cls):
        """
        Switch to GPU modules.

        This replaces the class attributes `xp`, `sparse_lib`, `linalg` and `fft_lib` of `cls` by their CuPy
        counterparts, i.e. it affects all instances of `cls`.
        """
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        """Get the differentiation matrix. Needs to be implemented by the individual bases."""
        raise NotImplementedError()

    def get_integration_matrix(self):
        """Get the integration matrix. Needs to be implemented by the individual bases."""
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretizations change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
                        individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition

        Raises:
            NotImplementedError: Always in the base class
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Modes with absolute wavenumber in [kmin, kmax) are kept, all others are set to zero.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter. Defaults to the largest absolute wavenumber, which is
                        then filtered out as well.

        Returns:
            sparse matrix: SciPy CSC matrix, also when running on GPU
        """
        k = abs(self.get_wavenumbers())

        kmax = k.max() if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        Id = self.get_Id()
        if self.useGPU:
            # The masked assignment is done with a SciPy LIL matrix on the host, so the identity and the mask both
            # need to be moved there. Indexing a host matrix with a device mask fails.
            Id = Id.get()
            mask = mask.get()

        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

220 

221 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (arXiv:1905.10388).
    """

    def __init__(self, *args, transform_type='fft', x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        The domain [x0, x1] is mapped affinely onto the reference interval [-1, 1] of the Chebychev polynomials. Note
        that not all methods apply this map, e.g. the boundary condition rows are computed on the reference interval.

        Args:
            transform_type ('fft' or 'dct'): Either use DCT functions directly implemented in the transform library or
                                             use the FFT from the library to compute the DCT. Case-insensitive.
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)
        self.transform_type = transform_type

        # `transform` and `itransform` compare case-insensitively, so the FFT utilities must be set up for e.g. 'FFT' too
        if self.transform_type.lower() == 'fft':
            self.get_fft_utils()

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the conversion nor its reverse is known
        '''
        N = N if N else self.N
        sp = self.sparse_lib
        xp = self.xp

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)) / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)
            elif name[0] == name[-1]:
                # conversion within the same basis; respect the requested size
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # try to obtain the matrix by inverting the conversion in the opposite direction
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}') from E

            if self.sparse_lib == scipy.sparse:
                mat = self.sparse_lib.linalg.inv(fwd.tocsc())
            else:
                # invert on the CPU and move the result back to the GPU
                mat = self.sparse_lib.csc_matrix(scipy.sparse.linalg.inv(fwd.tocsc().get()))

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: If `lbnd` is not 0
        """
        if lbnd != 0:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')

        xp = self.xp
        N = self.N
        n = xp.arange(N)

        S = self.sparse_lib.diags(1 / (xp.arange(N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        S = S.tocsc()
        # use `xp` throughout such that this also works on GPU
        S[0, 1::2] = (
            (n / (2 * (n + 1)))[1::2]
            * (-1) ** (xp.arange(N // 2))
            / (xp.append(xp.ones(1, dtype=int), xp.arange(N // 2 - 1) + 1))
        ) * self.lin_trf_fac
        return S

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        xp = self.xp
        col = xp.arange(self.N)[xp.newaxis, :]
        row = xp.arange(self.N)[:, xp.newaxis]

        # D[k, j] = 2 j for k < j with j - k odd and zero otherwise, built without a Python double loop
        D = xp.where(row < col, 2 * col * ((col - row) % 2), 0).astype(float)
        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def get_fft_shuffle(self, forward, N):
        """
        In order to more easily parallelize using distributed FFT libraries, we express the DCT via an FFT following
        doi.org/10.1109/TASSP.1980.1163351. The idea is based on reshuffling the data to be periodic and rotating it
        in the complex plane. This function returns a mask to do the shuffling.

        Args:
            forward (bool): Whether you want the shuffle for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Use as mask
        """
        xp = self.xp
        if forward:
            return xp.append(xp.arange((N + 1) // 2) * 2, -xp.arange(N // 2) * 2 - 1 - N % 2)
        else:
            mask = xp.zeros(N, dtype=int)
            mask[: N - N % 2 : 2] = xp.arange(N // 2)
            mask[1::2] = N - xp.arange(N // 2) - 1
            mask[-1] = N // 2
            return mask

    def get_fft_shift(self, forward, N):
        """
        As described in the docstring for `get_fft_shuffle`, we need to rotate in the complex plane in order to use FFT for DCT.

        Args:
            forward (bool): Whether you want the rotation for forward transform or backward transform
            N (int): size of the grid

        Returns:
            self.xp.array: Rotation
        """
        xp = self.xp
        # use the requested size rather than `self.N` for both the wavenumbers and the normalization
        k = xp.arange(N)
        norm = self.get_norm(N)
        if forward:
            return 2 * xp.exp(-1j * np.pi * k / (2 * N)) * norm
        else:
            shift = xp.exp(1j * np.pi * k / (2 * N))
            shift[0] = 0.5
            return shift / norm

    def get_fft_utils(self):
        """
        Get the required utilities for using FFT to do DCT as described in the docstring for `get_fft_shuffle` and keep
        them cached.
        """
        self.fft_utils = {
            'fwd': {},
            'bck': {},
        }

        # forwards transform
        self.fft_utils['fwd']['shuffle'] = self.get_fft_shuffle(True, self.N)
        self.fft_utils['fwd']['shift'] = self.get_fft_shift(True, self.N)

        # backwards transform
        self.fft_utils['bck']['shuffle'] = self.get_fft_shuffle(False, self.N)
        self.fft_utils['bck']['shift'] = self.get_fft_shift(False, self.N)

        return self.fft_utils

    @staticmethod
    def _broadcast_along_axis(values, ndim, axis):
        """
        Add singleton dimensions to the 1D array `values` such that it broadcasts along `axis` of an array with `ndim`
        dimensions.
        """
        expansion = [np.newaxis] * ndim
        expansion[axis] = slice(None)
        return values[tuple(expansion)]

    def transform(self, u, axis=-1, **kwargs):
        """
        1D DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in spectral space
        """
        transform_type = self.transform_type.lower()
        ndim = len(u.shape)

        if transform_type == 'dct':
            norm = self._broadcast_along_axis(self.norm, ndim, axis)
            return self.fft_lib.dct(u, axis=axis, **kwargs) * norm
        elif transform_type == 'fft':
            result = u.copy()

            # reorder the data such that the DCT can be computed by an FFT
            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['fwd']['shuffle']
            v = u[(*shuffle,)]

            V = self.fft_lib.fft(v, axis=axis, **kwargs)

            # rotate in the complex plane and normalize
            V *= self._broadcast_along_axis(self.fft_utils['fwd']['shift'], ndim, axis)

            # Only the real part of V is written into the copy of u; an imaginary part of u is left untouched.
            # NOTE(review): this presumably assumes real input data.
            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError(f'Please choose a transform type from fft and dct, not {self.transform_type=}')

    def itransform(self, u, axis=-1, **kwargs):
        """
        1D inverse DCT along axis. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            Data in physical space
        """
        assert self.norm.shape[0] == u.shape[axis]

        transform_type = self.transform_type.lower()
        ndim = len(u.shape)

        if transform_type == 'dct':
            norm = self._broadcast_along_axis(self.norm, ndim, axis)
            return self.fft_lib.idct(u / norm, axis=axis, **kwargs)
        elif transform_type == 'fft':
            result = u.copy()

            shift = self._broadcast_along_axis(self.fft_utils['bck']['shift'], ndim, axis)
            v = self.fft_lib.ifft(u * shift, axis=axis, **kwargs)

            # undo the reordering from the forward transform
            shuffle = [slice(0, s, 1) for s in u.shape]
            shuffle[axis] = self.fft_utils['bck']['shuffle']
            V = v[(*shuffle,)]

            result.real[...] = V.real[...]
            return result
        else:
            raise NotImplementedError

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the reference interval [-1, 1].

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition in the coordinates of the Chebychev polynomials, i.e. in
                       [-1, 1]. No affine map from [x0, x1] is applied here.

        Returns:
            self.xp.ndarray: Row to put into a matrix

        Raises:
            NotImplementedError: If x is outside of [-1, 1]
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        elif -1 < x < 1:
            # T_n(x) = cos(n arccos(x)) on [-1, 1]
            return self.xp.cos(self.xp.arange(self.N) * np.arccos(x))
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib
        xp = self.xp

        return sp.eye(N) - sp.diags(xp.ones(N - 2), offsets=+2)

586 

587 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # assembled with SciPy on the CPU and converted to `self.sparse_lib` below
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.diags(np.ones(N - 2), offsets=+2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix
        """
        if p_in == p_out:
            # Nothing to convert. Return the identity directly instead of inverting it on the CPU.
            return self.sparse_lib.eye(self.N, format='csc')

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min(p_in, p_out), max(p_in, p_out)):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
        if self.useGPU:
            mat_fwd = mat_fwd.get()

        mat_bck = scipy.sparse.linalg.inv(mat_fwd.tocsc())

        return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        xp = self.xp
        n = u_hat.shape[axis]

        # select all coefficients except the zeroth one along `axis`, keeping all other axes intact
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # alternating signs, shaped such that they broadcast along `axis`
        expansion = [np.newaxis] * u_hat.ndim
        expansion[axis] = slice(None)
        signs = ((-1) ** xp.arange(n - 1))[tuple(expansion)]

        return xp.sum(u_hat[tuple(slices)] * signs, axis=axis)

717 

718 

class FFTHelper(SpectralHelper1D):
    """
    Fourier basis for periodic problems. Transforms are done with the FFT from `fft_lib`.
    """

    def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float, optional): Coordinate of left boundary
            x1 (float, optional): Coordinate of right boundary
        """
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

    def get_1dgrid(self):
        """
        We use equally spaced points including the left boundary and not including the right one, which is identified
        with the left boundary due to periodicity.
        """
        dx = self.L / self.N
        return self.xp.arange(self.N) * dx + self.x0

    def get_wavenumbers(self):
        """
        Be careful that this ordering is very unintuitive: it follows the `fftfreq` convention of non-negative
        wavenumbers first, followed by the negative ones.
        """
        return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L

    def get_differentiation_matrix(self, p=1):
        """
        This matrix is diagonal, allowing to invert concurrently.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        k = self.get_wavenumbers()

        # The p-th power of a diagonal matrix is the diagonal matrix of the p-th powers of its entries. Computing it
        # this way avoids sparse matrix powers, which are not available in CuPy.
        return self.sparse_lib.diags((1j * k) ** p).tocsc()

    def get_integration_matrix(self, p=1):
        """
        Get integration matrix to compute `p`-th integral over the entire domain.

        Args:
            p (int): Order of integral you want to compute

        Returns:
            sparse integration matrix
        """
        k = self.xp.array(self.get_wavenumbers(), dtype='complex128')
        k[0] = 1j * self.L

        # diagonal matrix, so raise the entries to the power p (see `get_differentiation_matrix`)
        return self.sparse_lib.diags((1 / (1j * k)) ** p).tocsc()

    def transform(self, u, axis=-1, **kwargs):
        """
        1D FFT along axis. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.fft(u, axis=axis, **kwargs)

    def itransform(self, u, axis=-1, **kwargs):
        """
        Inverse 1D FFT. `kwargs` are passed on to the FFT library.

        Args:
            u: Data you want to transform
            axis (int): Axis you want to transform along

        Returns:
            transformed data
        """
        return self.fft_lib.ifft(u, axis=axis, **kwargs)

    def get_BC(self, kind):
        """
        Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
        The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
        You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

        Args:
            kind ('integral' or 'nyquist'): Kind of BC

        Returns:
            self.xp.ndarray: Boundary condition row
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row()
        elif kind.lower() == 'nyquist':
            assert (
                self.N % 2 == 0
            ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
            BC = self.xp.zeros(self.N)
            BC[self.get_Nyquist_mode_index()] = 1
            return BC
        else:
            return super().get_BC(kind)

    def get_Nyquist_mode_index(self):
        """
        Compute the index of the Nyquist mode, i.e. the mode with the lowest wavenumber, which doesn't have a positive
        counterpart for even resolution. This means real waves of this wave number cannot be properly resolved and you
        are best advised to set this mode zero if representing real functions on even-resolution grids is what you're
        after.

        Returns:
            int: Index of the Nyquist mode
        """
        k = self.get_wavenumbers()
        # `argmin` returns the first index of the smallest wavenumber
        return int(self.xp.argmin(k))

    def get_integ_BC_row(self):
        """
        Only the 0-mode has non-zero integral with FFT basis in periodic BCs
        """
        me = self.xp.zeros(self.N)
        me[0] = self.L / self.N
        return me

853 

854 

855class SpectralHelper: 

856 """ 

857 This class has three functions: 

858 - Easily assemble matrices containing multiple equations 

859 - Direct product of 1D bases to solve problems in more dimensions 

860 - Distribute the FFTs to facilitate concurrency. 

861 

862 Attributes: 

863 comm (mpi4py.Intracomm): MPI communicator 

864 debug (bool): Perform additional checks at extra computational cost 

865 useGPU (bool): Whether to use GPUs 

866 axes (list): List of 1D bases 

867 components (list): List of strings of the names of components in the equations 

868 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

869 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

870 BCs (sparse matrix): Matrix containing only the BCs 

871 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

872 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

873 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs 

874 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

875 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

876 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

877 local_slice (slice): Local slice of the solution as in `mpi4py-fft` 

878 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

879 init (tuple): This is the same `init` that is used throughout the problem classes 

880 init_forward (tuple): This is the equivalent of `init` in spectral space 

881 """ 

882 

883 xp = np 

884 fft_lib = scipy.fft 

885 sparse_lib = scipy.sparse 

886 linalg = scipy.sparse.linalg 

887 dtype = mesh 

888 fft_backend = 'fftw' 

889 fft_comm_backend = 'MPI' 

890 

891 @classmethod 

892 def setup_GPU(cls): 

893 """switch to GPU modules""" 

894 import cupy as cp 

895 import cupyx.scipy.sparse as sparse_lib 

896 import cupyx.scipy.sparse.linalg as linalg 

897 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh 

898 

899 cls.xp = cp 

900 cls.sparse_lib = sparse_lib 

901 cls.linalg = linalg 

902 

903 cls.fft_backend = 'cupy' 

904 cls.fft_comm_backend = 'NCCL' 

905 

906 cls.dtype = cupy_mesh 

907 

908 def __init__(self, comm=None, useGPU=False, debug=False): 

909 """ 

910 Constructor 

911 

912 Args: 

913 comm (mpi4py.Intracomm): MPI communicator 

914 useGPU (bool): Whether to use GPUs 

915 debug (bool): Perform additional checks at extra computational cost 

916 """ 

917 self.comm = comm 

918 self.debug = debug 

919 self.useGPU = useGPU 

920 

921 if useGPU: 

922 self.setup_GPU() 

923 

924 self.axes = [] 

925 self.components = [] 

926 

927 self.full_BCs = [] 

928 self.BC_mat = None 

929 self.BCs = None 

930 

931 self.fft_cache = {} 

932 self.fft_dealias_shape_cache = {} 

933 

934 @property 

935 def u_init(self): 

936 """ 

937 Get empty data container in physical space 

938 """ 

939 return self.dtype(self.init) 

940 

941 @property 

942 def u_init_forward(self): 

943 """ 

944 Get empty data container in spectral space 

945 """ 

946 return self.dtype(self.init_forward) 

947 

@property
def shape(self):
    """
    Local shape of a single solution component, i.e. the local data shape without the leading component axis.
    """
    local_shape = self.init[0]
    return local_shape[1:]

954 

@property
def ndim(self):
    """
    Number of spatial dimensions, i.e. the number of 1D bases added via `add_axis`.
    """
    bases = self.axes
    return len(bases)

958 

@property
def ncomponents(self):
    """
    Number of solution components added via `add_component`.
    """
    names = self.components
    return len(names)

962 

@property
def V(self):
    """
    Domain volume: the product of the lengths `L` of all axes.
    """
    lengths = [base.L for base in self.axes]
    return np.prod(lengths)

969 

def add_axis(self, base, *args, **kwargs):
    """
    Add an axis to the domain by deciding on suitable 1D base.
    Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the documentation of the 1D
    bases for possible arguments.

    The new base always receives `useGPU` from this helper, and afterwards shares its array (`xp`) and sparse
    (`sparse_lib`) modules. Chebychev bases default to `transform_type='fft'`.

    Args:
        base (str): 1D spectral method (case-insensitive)

    Raises:
        NotImplementedError: if `base` is not a known 1D method
    """
    kwargs['useGPU'] = self.useGPU
    method = base.lower()

    if method in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
        kwargs.setdefault('transform_type', 'fft')
        new_axis = ChebychevHelper(*args, **kwargs)
    elif method in ['fft', 'fourier', 'ffthelper']:
        new_axis = FFTHelper(*args, **kwargs)
    elif method in ['ultraspherical', 'gegenbauer']:
        new_axis = UltrasphericalHelper(*args, **kwargs)
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
    self.axes.append(new_axis)

992 

def add_component(self, name):
    """
    Add solution component(s).

    Lists and tuples are processed element by element, so nested lists/tuples of names are allowed.

    Args:
        name (str or list of strings): Name(s) of component(s)

    Raises:
        Exception: if a component with this name was already added
        NotImplementedError: if `name` is neither a string nor a list/tuple
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError(f'Don\'t know how to add a component of {type(name)=}')

1009 

def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str or list of strings): Name(s) of component(s)

    Returns:
        int: Index of the component, or a tuple of indices if a list/tuple of names is given

    Raises:
        ValueError: if a name is not a registered component
        NotImplementedError: for unsupported argument types
    """
    if type(name) in [str, int]:
        # NOTE(review): `add_component` only stores strings, so an int argument raises ValueError here -- confirm
        # whether ints were meant to be passed through as indices.
        return self.components.index(name)
    elif type(name) in [list, tuple]:
        # Evaluate eagerly: unknown names raise immediately, and the result can be iterated more than once.
        return tuple(self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

1026 

def get_empty_operator_matrix(self):
    """
    Return a matrix of operators to be filled with the connections between the solution components.

    Every entry is an independent copy of a zero matrix shaped like `self.get_Id()`. Modifying one entry in place
    therefore cannot leak into the others.

    Returns:
        list of lists (ncomponents x ncomponents) containing sparse zeros
    """
    S = len(self.components)
    O = self.get_Id() * 0
    return [[O.copy() for _ in range(S)] for _ in range(S)]

1037 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: for problems with more than two dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # 1D matrix that is zero except for row `line`, which holds the BC row of the 1D base. It is assembled with
    # scipy in lil format, so on GPUs the BC row is first copied to the host with `.get()`.
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        return self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        # scalar BCs only act on the first mode in the other direction, regular BCs act on all of them
        if scalar:
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        if self.useGPU:
            Id = Id.get()

        mats = [None] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        return self.sparse_lib.csc_matrix(sp.kron(*mats))
    else:
        # previously this fell through and returned None, which only failed later in `add_BC`
        raise NotImplementedError(f'Boundary conditions are not implemented for {ndim} dimensions!')

1086 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    The BC operator is subtracted from the list-of-lists BC matrix. On the rank that holds `line`, the corresponding
    rows of the rhs mask are also cleared.

    Args:
        component (str): Name of the component the BC acts on
        equation (str): Name of the equation the BC was put in
        axis (int): Axis the BC was added to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line the BC is in; negative values count from the end of the axis
        scalar (bool): Whether the BC was put in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    N = self.axes[axis].N
    owned = range(N)[self.local_slice[axis]]  # global indices along `axis` held by this rank
    global_line = (N + line) % N
    if global_line not in owned:
        return
    # translate to an index into the local array; correct for positive and negative `line` alike
    local_line = global_line - owned.start

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = local_line
    else:
        slices = (
            [self.index(equation)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis)]
            + [local_line]
            + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
        )
    self.BC_rhs_mask[(*slices,)] = False

1120 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Besides updating the BC operator, the BC is recorded in `self.full_BCs` and the affected rows are flagged in
    `self.BC_rhs_mask`.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in; negative values count from the end of the axis
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        slices = [self.index(equation)] + [0] * self.ndim
        slices[axis + 1] = line
        # scalar BCs are only flagged on rank 0 when running with a communicator
        if not self.comm or self.comm.rank == 0:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        N = self.axes[axis].N
        owned = range(N)[self.local_slice[axis]]  # global indices along `axis` held by this rank
        global_line = (N + line) % N
        if global_line in owned:
            # translate to an index into the local array; `line - start` would be wrong for negative lines
            slices = (
                [self.index(equation)]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [global_line - owned.start]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            self.BC_rhs_mask[(*slices,)] = True

1172 

def setup_BCs(self):
    """
    Assemble the boundary conditions after all of them have been added or removed.

    This converts the list of lists of BCs to a single operator (`self.BCs`). Boundary bordering requires zeroing
    all other entries in rows that hold a BC, so a diagonal matrix (`self.BC_line_zero_matrix`) is built that has
    zeros in exactly those rows. The BC values are also prepared in spectral space (`self.rhs_BCs_hat`) so they can
    be added to a right hand side cheaply.
    """
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)

    n_dof = np.prod(self.init[0])
    self.BC_zero_index = self.xp.arange(n_dof)[self.BC_rhs_mask.flatten()]

    keep_rows = self.xp.ones(self.BCs.shape[0])
    keep_rows[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = self.sparse_lib.diags(keep_rows)

    physical_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(physical_BCs)

1190 

def check_BCs(self, u):
    """
    Check that the solution satisfies the boundary conditions.

    Only the non-scalar BCs recorded in `self.full_BCs` are checked. For every axis that has such BCs, `u` is
    transformed along that axis only, and the 1D BC row of the respective base is applied to the transformed
    component. The result is compared to the prescribed value with `xp.allclose`, and a mismatch raises an
    AssertionError naming component, equation and line. Only one- and two-dimensional problems are supported.

    Args:
        u: The solution you want to check (it is transformed here, so presumably given in physical space)
    """
    assert self.ndim < 3
    for axis in range(self.ndim):
        # scalar BCs are skipped
        BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]

        if len(BCs) > 0:
            u_hat = self.transform(u, axes=(axis - self.ndim,))
            for BC in BCs:
                # forward everything except the bookkeeping entries (i.e. `kind` and any extra kwargs) to the 1D base
                kwargs = {
                    key: value
                    for key, value in BC.items()
                    if key not in ['component', 'equation', 'axis', 'v', 'line', 'scalar']
                }

                # apply the BC row from the left along axis 0 and from the right along axis 1
                if axis == 0:
                    get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
                elif axis == 1:
                    get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
                want = BC['v']
                assert self.xp.allclose(
                    get, want
                ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'

1219 

def put_BCs_in_matrix(self, A):
    """
    Put the boundary conditions in a matrix by replacing rows with BCs.

    The rows of `A` that carry a BC are zeroed by `self.BC_line_zero_matrix`, and the BC operator `self.BCs` is then
    added in their place.

    Args:
        A: Operator to put the BCs into

    Returns:
        Operator with BC rows replaced
    """
    A_without_BC_rows = self.BC_line_zero_matrix @ A
    return A_without_BC_rows + self.BCs

1225 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.

    Entries of `rhs_hat` in rows that hold a BC are set to zero *in place*, and the precomputed BC values
    `self.rhs_BCs_hat` are then added. The mask is built on the first call and reused afterwards, so BCs added
    after that first call are not taken into account.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        # Generate a mask of the entries in the rhs in spectral space that need to be set to zero so that they can be
        # replaced by the boundary conditions. The mask is then cached.
        self._rhs_hat_zero_mask = self.xp.zeros(shape=rhs_hat.shape, dtype=bool)

        for axis in range(self.ndim):
            N = self.axes[axis].N
            owned = range(N)[self.local_slice[axis]]  # global indices along `axis` held by this rank
            for bc in self.full_BCs:
                if axis != bc['axis']:
                    continue
                global_line = (N + bc['line']) % N
                if global_line not in owned:
                    continue
                # translate to an index into the local array; correct for positive and negative lines alike
                _slice = (
                    [self.index(bc['equation'])]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                    + [global_line - owned.start]
                    + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
                )
                self._rhs_hat_zero_mask[(*_slice,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1260 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        N = self.axes[axis].N
        owned = range(N)[self.local_slice[axis]]  # global indices along `axis` held by this rank
        for bc in self.full_BCs:
            if axis != bc['axis']:
                continue
            global_line = (N + bc['line']) % N
            if global_line not in owned:
                continue
            # translate to an index into the local array; correct for positive and negative lines alike
            _slice = (
                [self.index(bc['equation'])]
                + [slice(0, self.init[0][i + 1]) for i in range(axis)]
                + [global_line - owned.start]
                + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
            )
            _rhs_hat[(*_slice,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs

1298 

def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists of sparse matrices
    that you will convert to an operator later.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled; modified in place
        equation (str): The equation of the component you want this in (selects the row of `A`)
        relations: (dict): Maps component names (selecting the column of `A`) to the operator acting on them
    """
    for component, operator in relations.items():
        row, col = self.index(equation), self.index(component)
        A[row][col] = operator

1337 

def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as linear operator.
    See documentation of `SpectralHelper.add_equation_lhs` for an example.

    With a single component the only entry is returned as is; otherwise the blocks are assembled into one CSC
    matrix with `sparse_lib.bmat`.

    Args:
        M (list of lists of sparse matrices): The operator to be converted

    Returns:
        sparse linear operator
    """
    if len(self.components) != 1:
        return self.sparse_lib.bmat(M, format='csc')
    return M[0][0]

1353 

def get_wavenumbers(self):
    """
    Get grid in spectral space.

    The wavenumbers of every axis are restricted to the part held by this rank (`self.local_slice`) and combined
    with `xp.meshgrid(..., indexing='ij')`.
    """
    local_wavenumbers = [base.get_wavenumbers()[self.local_slice[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*local_wavenumbers, indexing='ij')

1360 

def get_grid(self):
    """
    Get grid in physical space.

    The 1D grid of every axis is restricted to the part held by this rank (`self.local_slice`) and combined with
    `xp.meshgrid(..., indexing='ij')`.
    """
    local_grids = [base.get_1dgrid()[self.local_slice[i]] for i, base in enumerate(self.axes)]
    return self.xp.meshgrid(*local_grids, indexing='ij')

1367 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    When using MPI, we use `PFFT` objects generated by mpi4py-fft

    Results are memoized in `self.fft_cache`, keyed by `(axes, direction, padding, shape)`. Without a communicator,
    'forward' / 'backward' give `xp.fft.fftn` / `xp.fft.ifftn` and 'object' gives None. With a communicator,
    'object' builds a `PFFT` and 'forward' / 'backward' give the bound `forward` / `backward` methods of the
    (cached) `PFFT` object for the same arguments.

    Args:
        axes (tuple): Axes you want to transform over (default: all axes, as negative indices -1, -2, ...)
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing (default: 1 in every direction, i.e. no padding)
        shape (tuple): Shape of the transform (default: global shape without the component axis)

    Returns:
        transform
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )
    # NOTE(review): `axes` goes into the key as given, so it must be hashable (e.g. a tuple, not a list)
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache.keys():
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # reuse (and cache) the PFFT object for the same arguments
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    # an unknown `direction` stores nothing above and therefore raises a KeyError here
    return self.fft_cache[key]

1428 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    If no component was added yet, a single component 'u' is added. The method then sets up the global shape and
    this rank's local slice. With a communicator the slice comes from the mpi4py-fft `PFFT` object; otherwise it
    covers the full domain. Finally it builds the `init` / `init_forward` tuples and empty BC bookkeeping.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)
    self.local_slice = [slice(0, me.N) for me in self.axes]

    axes = tuple(range(len(self.axes)))
    self.fft_obj = self.get_fft(axes=axes, direction='object')
    if self.fft_obj is not None:
        self.local_slice = self.fft_obj.local_slice(False)

    # Local shape: all components, restricted to this rank's slice in every spatial direction. Slicing a `range`
    # follows the same rules as slicing an array but avoids allocating a global-size array just to read its shape.
    local_shape = (self.global_shape[0],) + tuple(
        len(range(N)[_slice]) for N, _slice in zip(self.global_shape[1:], self.local_slice)
    )

    self.init = (
        local_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        local_shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.xp.zeros(
        shape=self.init[0],
        dtype=bool,
    )

1474 

def _transform_fft(self, u, axes, **kwargs):
    """
    Forward FFT of `u` along `axes`.

    The transform is obtained from `get_fft(axes, 'forward', **kwargs)`.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over

    Returns:
        transformed solution
    """
    # TODO: clean up and try putting more of this in the 1D bases
    forward = self.get_fft(axes, 'forward', **kwargs)
    return forward(u, axes=axes)

1489 

def _transform_dct(self, u, axes, padding=None, **kwargs):
    '''
    DCT along `axes`.
    This will only return real values!
    When padding the solution, we cannot just use the mpi4py-fft implementation, because of the unusual ordering of
    wavenumbers in FFTs.

    Each 1D DCT is computed as an FFT of the data reordered by `base.get_fft_shuffle`, followed by multiplication
    with a shift obtained from the base. Multiple axes are handled recursively, one axis at a time.

    Args:
        u: The solution
        axes (tuple): Axes you want to transform over
        padding (list): Padding factor per axis; if given, modes beyond `shape / padding` are discarded along axes
            with a factor different from 1

    Returns:
        transformed solution
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    if len(axes) > 1:
        # NOTE(review): `padding` is not forwarded to the recursive calls, so padding is ignored when several axes
        # are transformed at once -- confirm whether that is intended.
        v = self._transform_dct(self._transform_dct(u, axes[1:], **kwargs), (axes[0],), **kwargs)
    else:
        v = u.copy().astype(complex)
        axis = axes[0]
        base = self.axes[axis]

        # reorder the data along `axis` as required by the FFT-based DCT
        shuffle = [slice(0, s, 1) for s in u.shape]
        shuffle[axis] = base.get_fft_shuffle(True, N=v.shape[axis])
        v = v[(*shuffle,)]

        if padding is not None:
            # the shape of the (possibly padded) data along axis 0 is the sum over all ranks of the local sizes
            # (default reduction op of `Allreduce`)
            shape = list(v.shape)
            # NOTE(review): `fft_dealias_shape_cache` is read here but never written in this method
            if ('forward', *padding) in self.fft_dealias_shape_cache.keys():
                shape[0] = self.fft_dealias_shape_cache[('forward', *padding)]
            elif self.comm:
                send_buf = np.array(v.shape[0])
                recv_buf = np.array(v.shape[0])
                self.comm.Allreduce(send_buf, recv_buf)
                shape[0] = int(recv_buf)
            fft = self.get_fft(axes, 'forward', shape=shape)
        else:
            fft = self.get_fft(axes, 'forward', **kwargs)

        v = fft(v, axes=axes)

        # broadcast the 1D shift along `axis` over all other dimensions
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        if padding is not None:
            shift = base.get_fft_shift(True, v.shape[axis])

            if padding[axis] != 1:
                # discard the upper modes that only exist due to padding
                N = int(np.ceil(v.shape[axis] / padding[axis]))
                _expansion = [slice(0, n) for n in v.shape]
                _expansion[axis] = slice(0, N, 1)
                v = v[(*_expansion,)]
                # NOTE(review): `shift` and `expansion` were sized before this truncation, so their length along
                # `axis` no longer matches `v` here -- verify this branch is exercised and correct.
        else:
            shift = base.fft_utils['fwd']['shift']

        v *= shift[(*expansion,)]

    return v.real

1551 

def transform_single_component(self, u, axes=None, padding=None):
    """
    Transform a single component of the solution

    Axes are grouped by the type of their 1D base, and each group is transformed in one go (DCT for Chebychev and
    ultraspherical bases, FFT for Fourier bases). With a communicator, the data is realigned with `get_aligned`
    before and after each group so that the transformed axes are held completely by this rank. The result is also
    multiplied by the product of the resolutions of the transformed axes.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform (default: all axes, as negative indices)
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of modes after transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    trfs = {
        ChebychevHelper: self._transform_dct,
        UltrasphericalHelper: self._transform_dct,
        FFTHelper: self._transform_fft,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    # default: no padding in any direction
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by base type and drop base types without any requested axis
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = [max(u.shape[i], self.global_shape[1 + i]) for i in range(self.ndim)]

    fft = self.get_fft(axes=axes, padding=padding, direction='object')
    if fft is not None:
        shape = list(fft.global_shape(False))

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        # axes transformed in this step take their (unpadded) global size
        for _ax in _axes:
            shape[_ax] = self.global_shape[1 + self.ndim + _ax]

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[-1], forward=False, fft=fft, shape=shape
        )

        alignment = self.ndim + _axes[-1]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        if self.comm is not None:
            _out *= np.prod([self.axes[i].N for i in _axes])

        # align the output for the next group of axes (or keep the current alignment after the last group)
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[-1]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[0], axis_out=alignment, fft=fft, forward=True, shape=shape
        )

    return result

1623 

def transform(self, u, axes=None, padding=None):
    """
    Transform all components from physical space to spectral space, one component at a time.

    Args:
        u: Data to transform; the first index selects the component
        axes (tuple): Axes over which to transform (default: all spatial axes, as negative indices)
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will discard the upper half of
            modes after transforming (default: no padding)

    Returns:
        Transformed data, with the components stacked along the first axis
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    per_component = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        per_component[idx] = self.transform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(per_component)

1655 

def _transform_ifft(self, u, axes, **kwargs):
    """
    Inverse FFT of `u` along `axes`.

    The transform is obtained from `get_fft(axes, 'backward', **kwargs)`.
    """
    # TODO: clean up and try putting more of this in the 1D bases
    backward = self.get_fft(axes, 'backward', **kwargs)
    return backward(u, axes=axes)

1660 

def _transform_idct(self, u, axes, padding=None, **kwargs):
    '''
    Inverse DCT along `axes`.
    This will only ever return real values!

    Each 1D inverse DCT multiplies by a shift, applies an inverse FFT and then undoes the reordering with
    `base.get_fft_shuffle`. Multiple axes are handled recursively, one axis at a time. If `padding[axis] != 1`,
    the spectral data is zero-padded to `ceil(size * padding[axis])` modes before transforming.

    Args:
        u: Data in spectral space
        axes (tuple): Axes you want to transform over
        padding (list): Padding factor per axis

    Returns:
        transformed solution
    '''
    # TODO: clean up and try putting more of this in the 1D bases
    if self.debug:
        assert self.xp.allclose(u.imag, 0), 'This function can only handle real input.'

    v = u.copy().astype(complex)

    if len(axes) > 1:
        # NOTE(review): neither `padding` nor `kwargs` are forwarded to the recursive calls -- confirm intended
        v = self._transform_idct(self._transform_idct(u, axes[1:]), (axes[0],))
    else:
        axis = axes[0]
        base = self.axes[axis]

        if padding is not None:
            if padding[axis] != 1:
                # append zero modes along `axis` and build the matching shift for the padded size
                N_pad = int(np.ceil(v.shape[axis] * padding[axis]))
                _pad = [[0, 0] for _ in v.shape]
                _pad[axis] = [0, N_pad - base.N]
                v = self.xp.pad(v, _pad, 'constant')

                shift = self.xp.exp(1j * np.pi * self.xp.arange(N_pad) / (2 * N_pad)) * base.N
            else:
                shift = base.fft_utils['bck']['shift']
        else:
            shift = base.fft_utils['bck']['shift']

        # broadcast the 1D shift along `axis` over all other dimensions
        expansion = [np.newaxis for _ in u.shape]
        expansion[axis] = slice(0, v.shape[axis], 1)

        v *= shift[(*expansion,)]

        if padding is not None:
            if padding[axis] != 1:
                # the shape of the padded data along axis 0 is the sum over all ranks of the local sizes
                shape = list(v.shape)
                # NOTE(review): `fft_dealias_shape_cache` is read here but never written in this method
                if ('backward', *padding) in self.fft_dealias_shape_cache.keys():
                    shape[0] = self.fft_dealias_shape_cache[('backward', *padding)]
                elif self.comm:
                    send_buf = np.array(v.shape[0])
                    recv_buf = np.array(v.shape[0])
                    self.comm.Allreduce(send_buf, recv_buf)
                    shape[0] = int(recv_buf)
                ifft = self.get_fft(axes, 'backward', shape=shape)
            else:
                ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        else:
            ifft = self.get_fft(axes, 'backward', padding=padding, **kwargs)
        v = ifft(v, axes=axes)

        # undo the reordering used by the FFT-based DCT
        shuffle = [slice(0, s, 1) for s in v.shape]
        shuffle[axis] = base.get_fft_shuffle(False, N=v.shape[axis])
        v = v[(*shuffle,)]

    return v.real

1717 

def itransform_single_component(self, u, axes=None, padding=None):
    """
    Inverse transform over single component of the solution

    Axes are grouped by the type of their 1D base, and each group is inverse transformed in one go (inverse FFT for
    Fourier bases, inverse DCT for Chebychev and ultraspherical bases). With a communicator, the data is realigned
    with `get_aligned` before and after each group and divided by the product of the resolutions of the transformed
    axes before transforming.

    Args:
        u data to transform:
        axes (tuple): Axes over which to transform (default: all axes, as negative indices)
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there were modes before transforming

    Returns:
        Transformed data
    """
    # TODO: clean up and try putting more of this in the 1D bases
    trfs = {
        FFTHelper: self._transform_ifft,
        ChebychevHelper: self._transform_idct,
        UltrasphericalHelper: self._transform_idct,
    }

    axes = tuple(-i - 1 for i in range(self.ndim)[::-1]) if axes is None else axes
    # default: no padding in any direction
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )

    result = u.copy().astype(complex)
    alignment = self.ndim - 1

    # group the requested axes by base type and drop base types without any requested axis
    axes_collapsed = [tuple(sorted(me for me in axes if type(self.axes[me]) == base)) for base in trfs.keys()]
    bases = [list(trfs.keys())[i] for i in range(len(axes_collapsed)) if len(axes_collapsed[i]) > 0]
    axes_collapsed = [me for me in axes_collapsed if len(me) > 0]
    shape = list(self.global_shape[1:])

    for trf in range(len(axes_collapsed)):
        _axes = axes_collapsed[trf]
        base = bases[trf]

        if len(_axes) == 0:
            continue

        fft = self.get_fft(_axes, 'object', padding=padding, shape=shape)

        _in = self.get_aligned(
            result, axis_in=alignment, axis_out=self.ndim + _axes[0], forward=True, fft=fft, shape=shape
        )
        if self.comm is not None:
            _in /= np.prod([self.axes[i].N for i in _axes])

        alignment = self.ndim + _axes[0]

        _out = trfs[base](_in, axes=_axes, padding=padding, shape=shape)

        # update the shape of the transformed axes (they may have been padded)
        for _ax in _axes:
            if fft:
                shape[_ax] = fft._input_shape[_ax]
            else:
                shape[_ax] = _out.shape[_ax]

        # align the output for the next group of axes (or keep the current alignment after the last group)
        axes_next_base = (axes_collapsed + [(-1,)])[trf + 1]
        alignment = alignment if len(axes_next_base) == 0 else self.ndim + axes_next_base[0]
        result = self.get_aligned(
            _out, axis_in=self.ndim + _axes[-1], axis_out=alignment, fft=fft, forward=False, shape=shape
        )

    return result

1787 

def get_aligned(self, u, axis_in, axis_out, fft=None, forward=False, **kwargs):
    """
    Realign the data along the axis when using distributed FFTs. `kwargs` will be used to get the correct PFFT
    object from `mpi4py-fft`, which has suitable transfer classes for the shape of data. Hence, they should include
    shape especially, if applicable.

    Without a communicator, on a single rank, or if no realignment is requested, a copy of `u` is returned.
    Otherwise the transfer objects of the PFFT object are chained forward or backward where possible. If the
    transfer objects cannot be chained, the data is redistributed via `newDistArray` instead.

    Args:
        u: The solution
        axis_in (int): Current alignment
        axis_out (int): New alignment
        fft (mpi4py_fft.PFFT), optional: parallel FFT object, only used in the fallback redistribution
        forward (bool): Whether the input is in spectral space or not

    Returns:
        solution aligned on `axis_out`
    """
    if self.comm is None or axis_in == axis_out:
        return u.copy()
    if self.comm.size == 1:
        return u.copy()

    global_fft = self.get_fft(**kwargs)
    axisA = [me.axisA for me in global_fft.transfer]
    axisB = [me.axisB for me in global_fft.transfer]

    current_axis = axis_in

    if axis_in in axisA and axis_out in axisB:
        # follow the chain of transfer objects in forward direction until `axis_out` is reached
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisA.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayA[:] = u[:]

            transfer.forward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisB
            u = arrayB

        return u
    elif axis_in in axisB and axis_out in axisA:
        # follow the chain of transfer objects in backward direction until `axis_out` is reached
        while current_axis != axis_out:
            transfer = global_fft.transfer[axisB.index(current_axis)]

            arrayB = self.xp.empty(shape=transfer.subshapeB, dtype=transfer.dtype)
            arrayA = self.xp.empty(shape=transfer.subshapeA, dtype=transfer.dtype)
            arrayB[:] = u[:]

            transfer.backward(arrayA=arrayA, arrayB=arrayB)

            current_axis = transfer.axisA
            u = arrayA

        return u
    else:  # go the potentially slower route of not reusing transfer classes
        from mpi4py_fft import newDistArray

        fft = self.get_fft(**kwargs) if fft is None else fft

        _in = newDistArray(fft, forward).redistribute(axis_in)
        _in[...] = u

        return _in.redistribute(axis_out)

1852 

def itransform(self, u, axes=None, padding=None):
    """
    Inverse transform all components from spectral space to physical space, one component at a time.

    Args:
        u: Data to transform; the first index selects the component
        axes (tuple): Axes over which to transform (default: all spatial axes, as negative indices)
        padding (list): Padding factor for transform. E.g. a padding factor of 2 will add as many zeros as there
            were modes before transforming (default: no padding)

    Returns:
        Transformed data, with the components stacked along the first axis
    """
    if axes is None:
        axes = tuple(range(-self.ndim, 0))
    if padding is None:
        padding = [1] * self.ndim

    per_component = [None] * self.ncomponents
    for name in self.components:
        idx = self.index(name)
        per_component[idx] = self.itransform_single_component(u[idx], axes=axes, padding=padding)

    return self.xp.stack(per_component)

1884 

def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Restrict a global 1D matrix to the modes carried by this rank.

    With distributed FFTs every rank only holds a subset of the modes, described by `self.local_slice`. This
    method cuts both the rows and the columns of `M` down to that subset along `axis`, so that only slab
    decompositions are supported. Directions other than `axis` are not affected.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction whose local modes you want to keep

    Returns:
        sparse local matrix in CSC format
    """
    local_modes = self.local_slice[axis]
    csc = M.tocsc()
    return csc[local_modes, local_modes]

1899 

def expand_matrix_ND(self, matrix, aligned):
    """
    Lift a 1D matrix acting along axis `aligned` to an operator on the full (local) problem.

    In 1D the matrix is returned unchanged. In 2D the result is the Kronecker product of the local slice of
    `matrix` along `aligned` and the local slice of an identity along the other axis, with the factors ordered by
    axis index. Higher dimensions are not supported.

    Args:
        matrix (sparse matrix): Global 1D matrix along axis `aligned`
        aligned (int): Axis the 1D matrix acts along (negative values count from the back)

    Returns:
        sparse matrix

    Raises:
        NotImplementedError: If the problem has more than two dimensions
    """
    sp = self.sparse_lib
    other_axes = np.delete(np.arange(self.ndim), aligned)
    ndim = len(other_axes) + 1

    if ndim == 1:
        return matrix
    if ndim != 2:
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    axis = other_axes[0]
    identity = sp.eye(self.axes[axis].N)

    factors = [None, None]
    factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
    factors[axis] = self.get_local_slice_of_1D_matrix(identity, axis)
    return sp.kron(*factors)

1919 

def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation of `get_filter_matrix` in the 1D bases for what kwargs
    are admissible.

    In more than one dimension, the 1D filter is expanded with `expand_matrix_ND`, like every other operator built
    by this class. It therefore gets identities in the remaining directions and is restricted to the modes carried
    by this rank. Building a global Kronecker product here would give operators whose shape does not match the
    local data when the FFT is distributed.

    Args:
        axis (int): Axis along which to filter
        **kwargs: Forwarded to the `get_filter_matrix` method of the 1D base along `axis`

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    filter_1D = self.axes[axis].get_filter_matrix(**kwargs)
    return self.expand_matrix_ND(filter_1D, axis)

1934 

def get_differentiation_matrix(self, axes, **kwargs):
    """
    Build the sparse differentiation matrix for the full problem by chaining 1D derivatives.

    Each axis in `axes` contributes its 1D differentiation matrix, expanded to all dimensions with
    `expand_matrix_ND`. The expanded matrices are multiplied together in the order the axes are given.

    Args:
        axes (tuple): Axes along which to differentiate.
        **kwargs: Forwarded to the `get_differentiation_matrix` method of each 1D base.

    Returns:
        sparse differentiation matrix
    """

    def _expanded_derivative(axis):
        # 1D derivative along `axis`, lifted to the full operator
        return self.expand_matrix_ND(self.axes[axis].get_differentiation_matrix(**kwargs), axis)

    derivative = _expanded_derivative(axes[0])
    for axis in axes[1:]:
        derivative = derivative @ _expanded_derivative(axis)
    return derivative

1951 

def get_integration_matrix(self, axes, **kwargs):
    """
    Get integration matrix to integrate along specified axes.

    Each axis in `axes` contributes its 1D integration matrix, expanded to all dimensions with `expand_matrix_ND`.
    The expanded matrices are multiplied together in the order the axes are given.

    Args:
        axes (tuple): Axes along which to integrate over.
        **kwargs: Forwarded to the `get_integration_matrix` method of each 1D base, consistent with
            `get_differentiation_matrix` and `get_basis_change_matrix`. Without kwargs the 1D bases are called
            without arguments, exactly as before.

    Returns:
        sparse integration matrix
    """
    S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(**kwargs), axes[0])
    for axis in axes[1:]:
        _S = self.axes[axis].get_integration_matrix(**kwargs)
        S = S @ self.expand_matrix_ND(_S, axis)

    return S

1968 

def get_Id(self):
    """
    Get identity matrix for the full problem.

    The 1D identities of all `ndim` axes are expanded with `expand_matrix_ND` and multiplied together in axis
    order.

    Returns:
        sparse identity matrix
    """
    identity = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
    axis = 1
    while axis < self.ndim:
        identity = identity @ self.expand_matrix_ND(self.axes[axis].get_Id(), axis)
        axis += 1
    return identity

1981 

def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get Dirichlet recombination matrix along `axis`, expanded to the full problem.

    Note that this only makes sense in directions discretized with variations of Chebychev bases, because the
    matrix is taken from the 1D base along `axis`.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    base = self.axes[axis]
    return self.expand_matrix_ND(base.get_Dirichlet_recombination_matrix(), axis)

1994 

def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Get a matrix that changes the basis along the given axes, for bases that change basis while differentiating.

    Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass as `kwargs`.
    Each 1D basis change matrix is expanded with `expand_matrix_ND`. The expanded matrices are multiplied in the
    order the axes are given.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to (-1, -2, ..., -ndim).
        **kwargs: Forwarded to the `get_basis_change_matrix` method of each 1D base.

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(range(-1, -self.ndim - 1, -1))

    def _expanded_change(axis):
        return self.expand_matrix_ND(self.axes[axis].get_basis_change_matrix(**kwargs), axis)

    change = _expanded_change(axes[0])
    for axis in axes[1:]:
        change = change @ _expanded_change(axis)
    return change