Coverage for pySDC/helpers/spectral_helper.py: 83%

865 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-12 05:46 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5from functools import partial, wraps 

6import logging 

7 

8 

def cache(func):
    """
    Decorator for caching return values of *methods*.
    This is very similar to `functools.cache`, but without the memory leaks (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/): the cache lives in a dictionary stored as an attribute
    on the instance itself, so it is released together with the instance.

    The first positional argument of the decorated callable is treated as `self`. The cache attribute is set on it,
    so the decorator is meant for methods, not for free functions. Cache keys are built from the remaining positional
    arguments and the keyword arguments. All of them must therefore be hashable. Note that `f(0)` and `f(x=0)` are
    cached under different keys.

    Example:

    .. code-block:: python

        class Foo:
            num_calls = 0

            @cache
            def increment(self, x):
                self.num_calls += 1
                return x + 1

        foo = Foo()
        foo.increment(0)  # returns 1, foo.num_calls = 1
        foo.increment(1)  # returns 2, foo.num_calls = 2
        foo.increment(0)  # returns 1, foo.num_calls = 2

    Args:
        func (function): The method you want to cache the return value of

    Returns:
        function: Wrapped method that returns cached values for repeated arguments
    """
    # one cache dictionary per decorated method, stored on the instance under this name
    attr_cache = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, attr_cache):
            setattr(self, attr_cache, {})

        cache = getattr(self, attr_cache)

        key = (args, frozenset(kwargs.items()))
        if key in cache:
            return cache[key]
        result = func(self, *args, **kwargs)
        cache[key] = result
        return result

    return wrapper

54 

55 

class vkFFT(object):
    """
    pyVkFFT FFT backend.
    The special feature of vkFFT is fast DCT on GPU with cached plans.

    The static methods mimic the signatures of the corresponding `scipy.fft` functions, so the class can be used as
    an `fft_lib`. Only the options checked by the assertions are supported. None of the transforms modify their
    input array.
    """

    # Plans are shared between all calls and keyed by the full transform configuration.
    cached_plans = {}

    @staticmethod
    def is_complex(x):
        """Return whether the array `x` has a complex dtype."""
        return 'complex' in str(x.dtype)

    @staticmethod
    def _get_axes(x, axes):
        """Return `axes` as a tuple, defaulting to all axes of `x` (as `scipy.fft` does for `axes=None`)."""
        return tuple(range(x.ndim)) if axes is None else tuple(axes)

    @staticmethod
    def get_plan(transform_type, shape, dtype, axes, norm):
        """
        Get a (cached) VkFFT plan.

        Args:
            transform_type (str): 'fft' or 'dct' (the latter requests a DCT of type 2)
            shape (tuple): Shape of the data the plan is applied to
            dtype: dtype of the data the plan is applied to
            axes (tuple): Axes to transform along
            norm (str): Only 'backward' is supported

        Returns:
            VkFFTApp: The plan
        """
        assert norm == 'backward'

        key = f'{transform_type=}, {shape=}, {dtype=}, {axes=}, {norm=}'

        if key not in vkFFT.cached_plans:
            # only import pyvkfft when a new plan actually has to be generated
            from pyvkfft.cuda import VkFFTApp

            kwargs = {}

            if transform_type == 'dct':
                kwargs['dct'] = 2

            vkFFT.cached_plans[key] = VkFFTApp(shape, dtype, len(axes), axes=axes, norm=norm, **kwargs)

            logger = logging.getLogger(name='VkFFT')
            logger.debug(f'Cached plan for VkFFT: {key}')
        return vkFFT.cached_plans[key]

    @staticmethod
    def fftn(x, s=None, axes=None, norm='backward', overwrite_x=False):
        """Forward FFT of a copy of `x` along `axes`. The result is always complex."""
        assert not overwrite_x  # for consistent interface with scipy
        assert norm == 'backward'  # for consistent interface with scipy
        axes = vkFFT._get_axes(x, axes)

        # Cast to complex first; this also makes a copy. The plan must be created for the dtype of the array it
        # is applied to, i.e. the complex copy rather than the (possibly real) input.
        _x = x + 0j
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=_x.shape,
            dtype=_x.dtype,
            axes=axes,
            norm=norm,
        )
        plan.fft(_x)
        return _x

    @staticmethod
    def ifftn(x, s=None, axes=None, norm='forward', overwrite_x=False):
        """Inverse FFT of a copy of `x` along `axes` with scipy's `norm='forward'` convention (no 1/n factor)."""
        assert norm == 'forward'
        assert not overwrite_x  # for consistent interface with scipy
        axes = vkFFT._get_axes(x, axes)

        _x = x + 0j  # promote to complex (copy); plan is created for this dtype
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=_x.shape,
            dtype=_x.dtype,
            axes=axes,
            norm='backward',
        )
        plan.ifft(_x)
        # The 'backward' plan divides the inverse by the number of transformed points, which is the *product* of the
        # lengths of all transformed axes. Undo that to obtain the 'forward' convention.
        return _x * np.prod([x.shape[i] for i in axes])

    @staticmethod
    def dctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """DCT of type 2 of a copy of `x` along `axes`. Complex input is transformed as real and imaginary parts."""
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy
        axes = vkFFT._get_axes(x, axes)

        is_complex = vkFFT.is_complex(x)

        # the DCT plan operates on real data
        dtype = x.dtype if not is_complex else x.real.dtype

        plan = vkFFT.get_plan(
            transform_type='dct',
            shape=x.shape,
            dtype=dtype,
            axes=axes,
            norm=norm,
        )

        if is_complex:
            x_real = x.real.copy()
            x_imag = x.imag.copy()

            plan.fft(x_real)
            plan.fft(x_imag)

            return x_real + 1j * x_imag
        else:
            # transform the copy, never the caller's array
            _x = x.copy()
            plan.fft(_x)
            return _x

    @staticmethod
    def idctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """Inverse of `dctn`, applied to a copy of `x`."""
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy
        axes = vkFFT._get_axes(x, axes)

        is_complex = vkFFT.is_complex(x)
        dtype = x.dtype if not is_complex else x.real.dtype

        plan = vkFFT.get_plan(
            transform_type='dct',
            shape=x.shape,
            dtype=dtype,
            axes=axes,
            norm=norm,
        )

        if is_complex:
            x_real = x.real.copy()
            x_imag = x.imag.copy()

            plan.ifft(x_real)
            plan.ifft(x_imag)

            return x_real + 1j * x_imag
        else:
            # transform the copy, never the caller's array
            _x = x.copy()
            plan.ifft(_x)
            return _x

179 

180 

class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs. These modules (`xp`, `sparse_lib`, `linalg`, `fft_lib`) and `dtype` are set on the class by
    `setup_GPU` / `setup_CPU`, so switching them affects every instance of that class.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs
    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both `useGPU` and `useFFTW` are requested
        """
        # Validate before changing any state. The setup methods switch class-wide modules, so raising afterwards
        # would leave the class reconfigured. Without CuPy, the GPU import would also fail first and hide this error.
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
            self.logger.debug('Set up for GPU')
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        import cupyx.scipy.fft as fft_lib
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = fft_lib
        # mirror `setup_CPU`, which sets the CPU datatype
        cls.dtype = cupy_mesh

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """switch to CPU modules"""

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        raise NotImplementedError()

    def get_integration_matrix(self):
        raise NotImplementedError()

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
            individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter. Modes with `kmin <= |k| < kmax` are kept, all others are zeroed.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter (exclusive; defaults to the largest |k|)

        Returns:
            sparse matrix
        """

        k = abs(self.get_wavenumbers())

        kmax = max(k) if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        if self.useGPU:
            Id = self.get_Id().get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

380 

381 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary, mapped affinely onto the reference coordinate -1
            x1 (float): Coordinate of right boundary, mapped affinely onto the reference coordinate +1
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generated, matrices are cached. So feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional, defaults to the resolution)

        Returns:
            scipy.sparse: Sparse conversion matrix

        Raises:
            NotImplementedError: If neither the requested conversion nor its reverse is known
        '''
        N = N if N else self.N
        sp = self.sparse_lib

        def get_forward_conv(name):
            if name == 'T2U':
                mat = (sp.eye(N) - sp.eye(N, k=2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.eye(N, k=2)
            elif name[0] == name[-1]:
                # conversion to the same basis; respect the requested size `N`
                mat = sp.eye(N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            # the backward conversion is the inverse of the forward one
            try:
                fwd = get_forward_conv(name[::-1])
            except NotImplementedError:
                raise NotImplementedError from E

            import scipy.sparse as scipy_sparse

            if self.sparse_lib == scipy_sparse:
                mat = self.sparse_lib.linalg.inv(fwd.tocsc())
            else:
                # sparse inversion is done on CPU and the result is moved back to the GPU
                mat = self.sparse_lib.csc_matrix(scipy_sparse.linalg.inv(fwd.tocsc().get()))

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Note that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix

        Raises:
            NotImplementedError: For any other `lbnd`
        """
        S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = self.xp.arange(self.N)
        if lbnd == 0:
            S = S.tocsc()
            # the first row fixes the integration constant
            S[0, 1::2] = (
                (n / (2 * (self.xp.arange(self.N) + 1)))[1::2]
                * (-1) ** (self.xp.arange(self.N // 2))
                / (np.append([1], self.xp.arange(self.N // 2 - 1) + 1))
            ) * self.lin_trf_fac
        else:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')
        return S

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        n = self.xp.arange(self.N, dtype=float)

        # integral of T_n over [-1, 1] is ((-1)^n + 1) / (1 - n^2) for n >= 2, 2 for n = 0 and 0 for n = 1
        weights = (-1) ** n + 1
        weights[2:] /= 1 - (n**2)[2:]

        # rescale from the reference interval of length 2 to the physical domain
        weights /= 2 / self.L
        return weights

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        # D[k, j] = 2 j for k < j with j - k odd, zero otherwise; built with array operations instead of element-wise
        # assignments, which are particularly slow on GPUs
        j = self.xp.arange(self.N)
        k = j[:, None]
        D = self.xp.where(k < j, 2.0 * j * ((j - k) % 2), 0.0)

        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * trf.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            # convert DCT output to Chebychev coefficients along this axis
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axes. The input is not modified.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        # Copy once and accumulate the per-axis rescaling in `_u`. Restarting from `u` for every axis would keep only
        # the normalization of the last axis.
        _u = u.copy()
        N = self.N

        for axis in axes:

            if _u.shape[axis] != N:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                padded = self.xp.zeros_like(_u)

                # copy first half
                su = [slice(None)] * _u.ndim
                su[axis] = slice(0, N // 2 + 1)
                padded[tuple(su)] = _u[tuple(su)]

                # copy second half
                su = [slice(None)] * _u.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * _u.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                padded[tuple(s_u)] = _u[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * _u.ndim
                    su[axis] = N // 2
                    padded[tuple(su)] *= 2

                _u = padded

            # undo the Chebychev normalization along this axis
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / N

            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral', 'dirichlet' or 'neumann'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        elif kind.lower() == 'neumann':
            return self.get_Neumann_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval [-1, 1].

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition in reference coordinates (-1, 0 or 1)

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            # T_n(0) = cos(n pi / 2): 1, 0, -1, 0, 1, ...
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Neumann_BC_row(self, x):
        """
        Get a row for generating Neumann BCs at x with T polynomials.

        Args:
            x (float): Position of the boundary condition in reference coordinates (-1 or 1)

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N, dtype='D')
        nn = n**2
        if x == -1:
            # T_n'(-1) = (-1)^(n+1) n^2
            me = nn
            me[1:] *= (-1) ** n[:-1]
            return me
        elif x == 1:
            return nn
        else:
            raise NotImplementedError(f'Don\'t know how to generate Neumann BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib

        return sp.eye(N) - sp.eye(N, k=2)

752 

753 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where
    the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all
    matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped
    from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right
    hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # built on CPU with scipy because the column assignment needs a LIL matrix
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base

        Returns:
            sparse conversion matrix (CSC format)
        """
        if p_in == p_out:
            # nothing to convert; avoid a pointless sparse inversion of the identity
            return self.sparse_lib.csc_matrix(self.sparse_lib.eye(self.N))

        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix`
        for details. Works for arrays of any dimension.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        n = u_hat.shape[axis]

        # select all coefficients but the zeroth along `axis` and everything along the other axes
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, n)

        # alternating signs (T_k(-1) = (-1)^k), shaped to broadcast along `axis` only
        sign_shape = [1] * u_hat.ndim
        sign_shape[axis] = n - 1
        signs = ((-1) ** self.xp.arange(n - 1)).reshape(sign_shape)

        return self.xp.sum(u_hat[tuple(slices)] * signs, axis=axis)

883 

884 

885class FFTHelper(SpectralHelper1D): 

886 distributable = True 

887 

888 def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs): 

889 """ 

890 Constructor. 

891 Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you 

892 may choose to run on GPUs via the `useGPU` argument. 

893 

894 Args: 

895 x0 (float, optional): Coordinate of left boundary 

896 x1 (float, optional): Coordinate of right boundary 

897 """ 

898 super().__init__(*args, x0=x0, x1=x1, **kwargs) 

899 

900 def get_1dgrid(self): 

901 """ 

902 We use equally spaced points including the left boundary and not including the right one, which is the left boundary. 

903 """ 

904 dx = self.L / self.N 

905 return self.xp.arange(self.N) * dx + self.x0 

906 

907 def get_wavenumbers(self): 

908 """ 

909 Be careful that this ordering is very unintuitive. 

910 """ 

911 return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L 

912 

913 def get_differentiation_matrix(self, p=1): 

914 """ 

915 This matrix is diagonal, allowing to invert concurrently. 

916 

917 Args: 

918 p (int): Order of the derivative 

919 

920 Returns: 

921 sparse differentiation matrix 

922 """ 

923 k = self.get_wavenumbers() 

924 

925 if self.useGPU: 

926 if p > 1: 

927 # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time of writing. 

928 from scipy.sparse.linalg import matrix_power 

929 

930 D = self.sparse_lib.diags(1j * k).get() 

931 return self.sparse_lib.csc_matrix(matrix_power(D, p)) 

932 else: 

933 return self.sparse_lib.diags(1j * k) 

934 else: 

935 return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p) 

936 

937 def get_integration_matrix(self, p=1): 

938 """ 

939 Get integration matrix to compute `p`-th integral over the entire domain. 

940 

941 Args: 

942 p (int): Order of integral you want to compute 

943 

944 Returns: 

945 sparse integration matrix 

946 """ 

947 k = self.xp.array(self.get_wavenumbers(), dtype='complex128') 

948 k[0] = 1j * self.L 

949 return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p) 

950 

951 def get_integration_weights(self): 

952 """Weights for integration across entire domain""" 

953 weights = self.xp.zeros(self.N) 

954 weights[0] = self.L / self.N 

955 return weights 

956 

957 def get_plan(self, u, forward, *args, **kwargs): 

958 if self.fft_lib.__name__ == 'mpi4py_fft.fftw': 

959 if 'axes' in kwargs.keys(): 

960 kwargs['axes'] = tuple(kwargs['axes']) 

961 key = (forward, u.shape, args, *(me for me in kwargs.values())) 

962 if key in self.plans.keys(): 

963 return self.plans[key] 

964 else: 

965 self.logger.debug(f'Generating FFT plan for {key=}') 

966 transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs) 

967 self.plans[key] = transform 

968 

969 return self.plans[key] 

970 else: 

971 if forward: 

972 return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward')) 

973 else: 

974 return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward')) 

975 

976 def transform(self, u, *args, axes=None, shape=None, **kwargs): 

977 """ 

978 FFT along axes. `kwargs` are passed on to the FFT library. 

979 

980 Args: 

981 u: Data you want to transform 

982 axes (tuple): Axes you want to transform over 

983 

984 Returns: 

985 transformed data 

986 """ 

987 axes = axes if axes else tuple(i for i in range(u.ndim)) 

988 kwargs['s'] = shape 

989 plan = self.get_plan(u, *args, forward=True, axes=axes, **kwargs) 

990 return plan(u, *args, axes=axes, **kwargs) 

991 

def itransform(self, u, *args, axes=None, shape=None, **kwargs):
    """
    Inverse FFT of `u`. The result is divided by the number of points of `u` in the transformed axes.

    Args:
        u: Data to be transformed
        axes (tuple): Axes to transform over; every axis of `u` when not given
        shape (tuple): Shape of the transform, passed to the FFT library as `s`

    Returns:
        transformed data
    """
    if not axes:
        axes = tuple(range(u.ndim))
    kwargs['s'] = shape
    backward_plan = self.get_plan(u, *args, forward=False, axes=axes, **kwargs)
    normalization = np.prod([u.shape[ax] for ax in axes])
    return backward_plan(u, *args, axes=axes, **kwargs) / normalization

1007 

def get_BC(self, kind):
    """
    Get a sort of boundary condition row.

    With `kind='integral'`, the row fixes the integral over the domain. With `kind='Nyquist'`, the row selects the
    Nyquist mode, which is not really a boundary condition but allows setting that mode to some value, preferably
    zero. This is advisable when the solution is real in physical space and the resolution is even. Other kinds are
    handed to the parent class.

    Args:
        kind ('integral' or 'nyquist'): Kind of BC, case insensitive

    Returns:
        self.xp.ndarray: Boundary condition row
    """
    requested = kind.lower()

    if requested == 'integral':
        return self.get_integ_BC_row()

    if requested == 'nyquist':
        assert (
            self.N % 2 == 0
        ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
        nyquist_row = self.xp.zeros(self.N)
        nyquist_row[self.get_Nyquist_mode_index()] = 1
        return nyquist_row

    return super().get_BC(kind)

1031 

def get_Nyquist_mode_index(self):
    """
    Locate the Nyquist mode, i.e. the mode with the lowest (most negative) wavenumber.

    With even resolution this mode has no positive counterpart, so real waves of this wavenumber cannot be resolved
    properly. It is best set to zero when representing real functions on even-resolution grids.

    Returns:
        int: Index of the first occurrence of the lowest wavenumber
    """
    wavenumbers = self.get_wavenumbers()
    lowest_wavenumber = min(wavenumbers)
    return self.xp.nonzero(wavenumbers == lowest_wavenumber)[0][0]

1045 

def get_integ_BC_row(self):
    """
    Row fixing the integral over the domain.

    Only the 0-mode has non-zero integral with FFT basis in periodic BCs. The row therefore has a single non-zero
    entry, `L / N`, in the zero mode.

    Returns:
        self.xp.ndarray: Boundary condition row
    """
    bc_row = self.xp.zeros(self.N)
    bc_row[0] = self.L / self.N
    return bc_row

1053 

1054 

1055class SpectralHelper: 

1056 """ 

1057 This class has three functions: 

1058 - Easily assemble matrices containing multiple equations 

1059 - Direct product of 1D bases to solve problems in more dimensions 

1060 - Distribute the FFTs to facilitate concurrency. 

1061 

1062 Attributes: 

1063 comm (mpi4py.Intracomm): MPI communicator 

1064 debug (bool): Perform additional checks at extra computational cost 

1065 useGPU (bool): Whether to use GPUs 

1066 axes (list): List of 1D bases 

1067 components (list): List of strings of the names of components in the equations 

1068 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

1069 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

1070 BCs (sparse matrix): Matrix containing only the BCs 

1071 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

1072 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

1073 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs 

1074 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

1075 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

1076 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

1077 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

1078 init (tuple): This is the same `init` that is used throughout the problem classes 

1079 init_forward (tuple): This is the equivalent of `init` in spectral space 

1080 """ 

1081 

1082 xp = np 

1083 fft_lib = scipy.fft 

1084 sparse_lib = scipy.sparse 

1085 linalg = scipy.sparse.linalg 

1086 dtype = mesh 

1087 fft_backend = 'scipy' 

1088 fft_comm_backend = 'MPI' 

1089 

@classmethod
def setup_GPU(cls):
    """
    Switch the class to GPU modules.

    Arrays come from cupy, sparse linear algebra and FFTs from cupyx, communication of distributed FFTs uses NCCL and
    data containers are `cupy_mesh`.
    """
    import cupy as cp
    import cupyx.scipy.sparse as sparse_lib
    import cupyx.scipy.sparse.linalg as linalg
    import cupyx.scipy.fft as fft_lib
    from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

    gpu_settings = {
        'xp': cp,
        'sparse_lib': sparse_lib,
        'linalg': linalg,
        'fft_lib': fft_lib,
        'fft_backend': 'cupyx-scipy',
        'fft_comm_backend': 'NCCL',
        'dtype': cupy_mesh,
    }
    for attribute, module in gpu_settings.items():
        setattr(cls, attribute, module)

1106 

1107 cls.dtype = cupy_mesh 

1108 

@classmethod
def setup_CPU(cls, useFFTW=False):
    """
    Switch the class to CPU modules.

    Arrays come from numpy and sparse linear algebra from scipy. FFTs come from `scipy.fft` by default or from FFTW
    via `mpi4py_fft` when requested. Distributed FFTs communicate with MPI and data containers are `mesh`.

    Args:
        useFFTW (bool): Use FFTW from mpi4py-fft instead of scipy for FFTs
    """
    cls.xp = np
    cls.sparse_lib = scipy.sparse
    cls.linalg = scipy.sparse.linalg

    if not useFFTW:
        cls.fft_backend, cls.fft_lib = 'scipy', scipy.fft
    else:
        from mpi4py_fft import fftw

        cls.fft_backend, cls.fft_lib = 'fftw', fftw

    cls.fft_comm_backend = 'MPI'
    cls.dtype = mesh

1128 

def __init__(self, comm=None, useGPU=False, debug=False):
    """
    Set up an empty spectral discretization without axes, components or boundary conditions.

    Args:
        comm (mpi4py.Intracomm): MPI communicator
        useGPU (bool): Whether to use GPUs
        debug (bool): Perform additional checks at extra computational cost and log at debug level
    """
    self.comm = comm
    self.debug = debug
    self.useGPU = useGPU

    select_modules = self.setup_GPU if useGPU else self.setup_CPU
    select_modules()

    self.axes, self.components = [], []

    self.full_BCs = []
    self.BC_mat = self.BCs = None

    self.fft_cache = {}

    self.logger = logging.getLogger(name='Spectral Discretization')
    if debug:
        self.logger.setLevel(logging.DEBUG)

1159 

@property
def u_init(self):
    """
    Fresh, empty data container in physical space, built from `self.init`.
    """
    physical_init = self.init
    return self.dtype(physical_init)

1166 

@property
def u_init_forward(self):
    """
    Fresh, empty data container in spectral space, built from `self.init_forward`.
    """
    spectral_init = self.init_forward
    return self.dtype(spectral_init)

1173 

@property
def u_init_physical(self):
    """
    Fresh, empty data container in physical space, built from `self.init_physical`.
    """
    physical_init = self.init_physical
    return self.dtype(physical_init)

1180 

@property
def shape(self):
    """
    Local shape of a single solution component, i.e. the shape in `init` without the leading component axis.
    """
    full_local_shape = self.init[0]
    return full_local_shape[1:]

1187 

@property
def ndim(self):
    """
    Number of spatial dimensions, i.e. the number of 1D bases added with `add_axis`.
    """
    number_of_axes = len(self.axes)
    return number_of_axes

1191 

@property
def ncomponents(self):
    """
    Number of solution components added with `add_component`.
    """
    number_of_components = len(self.components)
    return number_of_components

1195 

@property
def V(self):
    """
    Volume of the domain, i.e. the product of the lengths `L` of all axes.
    """
    side_lengths = [axis.L for axis in self.axes]
    return np.prod(side_lengths)

1202 

def add_axis(self, base, *args, **kwargs):
    """
    Extend the domain by one dimension, discretized with the 1D spectral base named `base`.

    `*args` and `**kwargs` are forwarded to the constructor of the 1D base. Please refer to the documentation of the
    1D bases for possible arguments. `useGPU` is always set from this helper.

    Args:
        base (str): 1D spectral method, e.g. 'chebychev', 'fft' or 'ultraspherical' (case insensitive)

    Raises:
        NotImplementedError: If the base is unknown
    """
    kwargs['useGPU'] = self.useGPU

    base_name = base.lower()
    if base_name in ['chebychov', 'chebychev', 'cheby', 'chebychovhelper']:
        helper_class = ChebychevHelper
    elif base_name in ['fft', 'fourier', 'ffthelper']:
        helper_class = FFTHelper
    elif base_name in ['ultraspherical', 'gegenbauer']:
        helper_class = UltrasphericalHelper
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    new_axis = helper_class(*args, **kwargs)
    # the 1D base must use the same array and sparse modules as this helper
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
    self.axes.append(new_axis)

1224 

def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list/tuple of strings): Name(s) of component(s)

    Raises:
        Exception: If a component of this name was already added
        NotImplementedError: If `name` is neither a string nor a list/tuple of strings
    """
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError(
            f'Cannot add component of type {type(name).__name__}. Use a string or a list/tuple of strings!'
        )

1241 

def index(self, name):
    """
    Get the index of component `name`.

    Integers are treated as component indices already. They are validated against the number of components, and
    negative values count from the end.

    Args:
        name (str, int or list/tuple of those): Name(s) or index/indices of component(s)

    Returns:
        int: Index of the component, or a generator of indices for a list/tuple

    Raises:
        ValueError: If the name is unknown or the integer index is out of range
        NotImplementedError: For unsupported types of `name`
    """
    if isinstance(name, str):
        return self.components.index(name)
    elif isinstance(name, (int, np.integer)) and not isinstance(name, bool):
        # components are only ever stored as strings, so looking up an integer in the list would always fail
        n_components = len(self.components)
        if not -n_components <= name < n_components:
            raise ValueError(f'Component index {name} is out of range for {n_components} components!')
        return int(name) % n_components
    elif isinstance(name, (list, tuple)):
        return (self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

1258 

def get_empty_operator_matrix(self, diag=False):
    """
    Build the container for the operators coupling the solution components, initially filled with sparse zeros.

    All entries refer to the same zero matrix object. The rows of the full matrix are distinct lists.

    Args:
        diag (bool): Return a flat list with one entry per component for block-diagonal operators

    Returns:
        list (of lists) containing sparse zeros
    """
    n_components = len(self.components)
    zero = self.get_Id() * 0
    if diag:
        return [zero] * n_components
    return [[zero] * n_components for _ in range(n_components)]

1275 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    The 1D BC row from `self.axes[axis].get_BC` is placed in row `line` of an otherwise zero `N x N` matrix. For 2D and
    3D, that matrix is combined via Kronecker products with matrices for the other axes, restricted to the local part
    with `get_local_slice_of_1D_matrix`.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): If True, the matrix for the other axes keeps only the (0, 0) entry, so the BC acts only on the
            first mode in the other direction(s). If False, the identity is used and the BC acts on all modes in the
            other direction(s)

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # the 1D matrix is assembled on the host with scipy in LIL format, which allows assigning rows
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # on GPU, the 1D row is presumably a device array; `.get()` copies it to the host
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    ndim = len(self.axes)
    if ndim == 1:
        mat = self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # keep only the first mode in the other direction
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        mats = [
            None,
        ] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        # the Kronecker product follows the order of the axes, irrespective of which axis carries the BC
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
    elif ndim == 3:
        mats = [
            None,
        ] * ndim

        for ax in range(ndim):
            if ax == axis:
                continue

            if scalar:
                # keep only the first mode in this direction
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
            else:
                _Id = self.axes[ax].get_Id()

            mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

        # `kron` takes two matrices, so the three factors are nested
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
    else:
        raise NotImplementedError(
            f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
        )
    mat = self.eliminate_zeros(mat)
    return mat

1346 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    The BC matrix from `get_BC` is subtracted from the operator, and the corresponding entry of the right hand side
    mask is reset on the task that holds `line`.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in; negative values count from the end of the global axis
        scalar (bool): Whether the BC was added as a scalar BC, i.e. acting only on the first mode in the other
            direction(s)
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    _BC = self.eliminate_zeros(_BC)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    N = self.axes[axis].N
    global_line = (N + line) % N
    if global_line not in self.xp.arange(N)[self.local_slice()[axis]]:
        # the line is held by another task, so the local mask is unaffected
        return

    if scalar:
        # mirrors `add_BC`, which sets the scalar BC in the mask at `line` and index 0 in all other axes
        slices = [self.index(equation)] + [
            0,
        ] * self.ndim
        slices[axis + 1] = line
    else:
        # `line` is a global index, but `BC_rhs_mask` only holds the local part of the data, just like in `add_BC`
        local_line = global_line - self.local_slice()[axis].start
        slices = (
            [self.index(equation)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis)]
            + [local_line]
            + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
        )
    self.BC_rhs_mask[(*slices,)] = False

1381 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Besides adding the BC matrix from `get_BC` to the operator, the BC is recorded in `self.full_BCs` and the rows it
    occupies are marked in `self.BC_rhs_mask`.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): If True, the BC acts only on the first mode in the other direction(s), see `get_BC`
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    # record all information about the BC; used for the right hand side and in `check_BCs`
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        # a scalar BC occupies a single entry: `line` in `axis` and index 0 in all other axes
        slices = [self.index(equation)] + [
            0,
        ] * self.ndim
        slices[axis + 1] = line
        if self.comm:
            # with a communicator, only rank 0 marks the entry
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        # only mark the line if this task holds it, converting the global line to a local index
        if (N + line) % N in self.get_indices(True)[axis]:
            slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = True

1428 

def setup_BCs(self):
    """
    Convert the list of lists of BCs to the boundary condition operator.
    Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
    condition. This method sets up a suitable sparse matrix to do this.

    Sets `BCs`, `BC_zero_index`, `BC_line_zero_matrix` and `rhs_BCs_hat`. `BC_rhs_mask` is deleted at the end.
    `add_BC` and `remove_BC` write to that mask, so they cannot be used after this call.
    """
    sp = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)
    # flat indices of all rows that contain a BC
    self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[self.BC_rhs_mask.flatten()]

    # diagonal matrix with zeros in the BC rows; multiplying with it zeros those rows before `self.BCs` is added
    diags = self.xp.ones(self.BCs.shape[0])
    diags[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sp.diags(diags).tocsc()

    # prepare BCs in spectral space to easily add to the RHS
    rhs_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(rhs_BCs).view(self.xp.ndarray)
    del self.BC_rhs_mask

1447 

def check_BCs(self, u):
    """
    Check that the solution satisfies the boundary conditions.

    Only BCs added with `scalar=False` are checked, and only in 1D and 2D. For each axis with such BCs, `u` is
    transformed along that axis alone. The 1D BC row is then applied to the transformed component and the result is
    compared with the BC value `v`.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If there are three or more dimensions or a BC is violated
    """
    assert self.ndim < 3
    for axis in range(self.ndim):
        BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]

        if len(BCs) > 0:
            # transform only along `axis`; the negative index counts from the last axis, skipping the component axis
            u_hat = self.transform(u, axes=(axis - self.ndim,))
            for BC in BCs:
                # all entries that are not bookkeeping, including `kind`, are arguments to the 1D `get_BC`
                kwargs = {
                    key: value
                    for key, value in BC.items()
                    if key not in ['component', 'equation', 'axis', 'v', 'line', 'scalar']
                }

                # apply the BC row along `axis` of the component
                if axis == 0:
                    get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
                elif axis == 1:
                    get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
                want = BC['v']
                assert self.xp.allclose(
                    get, want
                ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'

1476 

def put_BCs_in_matrix(self, A):
    """
    Replace the rows of `A` that hold boundary conditions by the BC rows.

    The BC rows are first zeroed with `BC_line_zero_matrix`, then the BC operator `BCs` is added.

    Args:
        A: Sparse operator without BCs

    Returns:
        Operator with BCs
    """
    rows_zeroed = self.BC_line_zero_matrix @ A
    return rows_zeroed + self.BCs

1482 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.

    Entries of `rhs_hat` in BC rows are set to zero and the precomputed `self.rhs_BCs_hat` from `setup_BCs` is added.

    NOTE(review): `rhs_hat` is modified in place, since its BC entries are zeroed before the sum is returned. The mask is
    built from `self.full_BCs` on the first call only, so BCs added afterwards are not reflected.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        """
        Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
        by the boundary conditions. The mask is then cached.
        """
        self._rhs_hat_zero_mask = self.newDistArray(forward_output=True).astype(bool).view(self.xp.ndarray)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if axis == bc['axis']:
                    slices = [self.index(bc['equation']), *self.global_slice(True)]
                    N = self.axes[axis].N
                    line = bc['line']
                    # only mark the line if this task holds it, converting the global line to a local index
                    if (N + line) % N in self.get_indices(True)[axis]:
                        slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
                        self._rhs_hat_zero_mask[(*slices,)] = True

    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1513 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    For every axis, `rhs` is transformed along that axis only, the BC values `v` are written to the BC lines held by
    this task, and the result is transformed back.
    NOTE(review): the forward/backward transform pair is performed for every axis, even if it carries no BCs.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        # transform only along the current axis; the negative index skips the leading component axis
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        for bc in self.full_BCs:

            if axis == bc['axis']:
                _slice = [self.index(bc['equation']), *self.global_slice(True)]

                N = self.axes[axis].N
                line = bc['line']
                # only set the value if this task holds the line, converting the global line to a local index
                if (N + line) % N in self.get_indices(True)[axis]:
                    _slice[axis + 1] = (N + line) % N - self.local_slice()[axis].start
                    _rhs_hat[(*_slice,)] = bc['v']

        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs

1547 

def add_equation_lhs(self, A, equation, relations):
    """
    Write the implicitly treated left hand side of one equation into a list of lists of sparse matrices. The result
    can later be converted to a single operator with `convert_operator_matrix_to_operator`.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form and T-to-U
        preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): Operator that is modified in place
        equation (str): Name of the equation, selecting the row of `A`
        relations (dict): Maps names of components to the operators acting on them in this equation
    """
    if relations:
        equation_row = A[self.index(equation)]
        for component, operator in relations.items():
            equation_row[self.index(component)] = operator

1586 

def eliminate_zeros(self, A):
    """
    Drop explicitly stored zeros from a sparse matrix to reduce its memory footprint.

    Note: At the time of writing, the cupy implementation of `eliminate_zeros` has memory problems. On GPU the matrix
    is therefore pruned on the host and copied back to the device afterwards.

    Args:
        A: sparse matrix to be pruned

    Returns:
        CSC sparse matrix
    """
    on_gpu = self.useGPU
    host_matrix = A.get() if on_gpu else A
    host_matrix = host_matrix.tocsc()
    host_matrix.eliminate_zeros()
    return self.sparse_lib.csc_matrix(host_matrix) if on_gpu else host_matrix

1606 

def convert_operator_matrix_to_operator(self, M):
    """
    Combine the list of lists of sparse matrices into one sparse matrix that can be used as a linear operator.
    See the documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): Blocks of the operator, one row per equation

    Returns:
        sparse linear operator with explicit zeros removed
    """
    if len(self.components) != 1:
        operator = self.sparse_lib.bmat(M, format='csc')
    else:
        # a single component needs no block assembly
        operator = M[0][0]
    return self.eliminate_zeros(operator)

1625 

def get_wavenumbers(self):
    """
    Local grid in spectral space.

    Returns:
        list of arrays: Wavenumbers of the local part of the data, one array per axis, from `meshgrid` with `ij` indexing
    """
    spectral_slices = self.local_slice(True)
    local_wavenumbers = []
    for i, base in enumerate(self.axes):
        local_wavenumbers.append(base.get_wavenumbers()[spectral_slices[i]])
    return self.xp.meshgrid(*local_wavenumbers, indexing='ij')

1632 

def get_grid(self, forward_output=False):
    """
    Local grid in physical space.

    Args:
        forward_output (bool): Use the distribution of the data in spectral space instead of physical space

    Returns:
        list of arrays: Coordinates of the local part of the data, one array per axis, from `meshgrid` with `ij` indexing
    """
    local_slices = self.local_slice(forward_output)
    local_coordinates = []
    for i, base in enumerate(self.axes):
        local_coordinates.append(base.get_1dgrid()[local_slices[i]])
    return self.xp.meshgrid(*local_coordinates, indexing='ij')

1639 

def get_indices(self, forward_output=True):
    """
    Global indices of the local part of the data along each axis.

    Args:
        forward_output (bool): Use the distribution in spectral space (True) or physical space (False)

    Returns:
        list of arrays: One array of global indices per axis
    """
    local_slices = self.local_slice(forward_output)
    return [self.xp.arange(base.N)[local_slices[i]] for i, base in enumerate(self.axes)]

1642 

@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Get a distributed FFT object from `mpi4py-fft` that applies the transforms of the 1D bases.

    Results are cached per instance by the `cache` decorator, which uses the arguments as a dictionary key. The
    arguments must therefore be hashable, e.g. tuples rather than lists.

    Args:
        axes (tuple): Axes to transform; all axes if not given. Axes that are not listed get an identity "transform"
        padding (sequence): Padding factor per axis; 1 in every axis if not given
        grid: Forwarded to `PFFT` as `grid`

    Returns:
        `mpi4py_fft.PFFT` object, or None for 1D problems or when no communicator is used
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT, newDistArray

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    def no_transform(u, *args, **kwargs):
        return u

    # one (forward, backward) pair per axis: identity by default, the transforms of the 1D base for requested axes
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )

    # do a transform to do the planning
    _u = newDistArray(pfft, forward_output=False)
    pfft.backward(pfft.forward(_u))
    return pfft

1684 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    When using MPI, we use `PFFT` objects generated by mpi4py-fft

    Results are cached in `self.fft_cache`. Without a communicator, `self.xp.fft.fftn`/`ifftn` are returned for
    'forward'/'backward' and None for 'object', and padding other than 1 is rejected. With a communicator, a `PFFT`
    object is created for 'object', and the directional variants return its `forward`/`backward` methods.

    Args:
        axes (tuple): Axes you want to transform over; all axes, counted from the end, if not given
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform

    Raises:
        AssertionError: For padding without a communicator
        KeyError: For unknown `direction`, because nothing is stored in the cache under that key
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache.keys():
        if self.comm is None:
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # reuse (or create and cache) the PFFT object for the same axes, padding and shape
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    return self.fft_cache[key]

1745 

def local_slice(self, forward_output=True):
    """
    Slices of the global data that are held by this task.

    Args:
        forward_output (bool): Use the distribution in spectral space (True) or physical space (False)

    Returns:
        list of slices, one per axis; the full axes when no distributed FFT is used
    """
    if not self.fft_obj:
        return [slice(0, base.N) for base in self.axes]
    return self.get_pfft().local_slice(forward_output=forward_output)

1751 

def global_slice(self, forward_output=True):
    """
    Slices covering the full global extent of each axis.

    Args:
        forward_output (bool): Use the global shape in spectral space (True) or physical space (False)

    Returns:
        list of slices, one per axis; the same as `local_slice` when no distributed FFT is used
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    global_sizes = self.fft_obj.global_shape(forward_output=forward_output)
    return [slice(0, size) for size in global_sizes]

1757 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    Adds a component 'u' if none was added. Sets `global_shape`, `fft_obj`, `init`, `init_physical`, `init_forward`,
    `BC_mat` and `BC_rhs_mask`.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_pfft(axes=axes)

    def _local_shape(forward_output):
        """
        Shape of the local part of an array of shape `self.global_shape`. The slices from `local_slice` apply to the
        trailing axes, as indexing with `(..., *slices)` would, but no global-size array is allocated.
        """
        slices = tuple(self.local_slice(forward_output))
        n_leading = len(self.global_shape) - len(slices)
        local_sizes = tuple(len(range(N)[_slice]) for N, _slice in zip(self.global_shape[n_leading:], slices))
        return self.global_shape[:n_leading] + local_sizes

    physical_shape = _local_shape(False)
    self.init = (
        physical_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_physical = (
        physical_shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        _local_shape(True),
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)

1807 

def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Args:
        pfft: `PFFT` object defining the distribution; `self.get_pfft()` if not given
        forward_output (bool): Use the global shape, distribution and dtype of the output (spectral space) if True,
            else of the input (physical space)
        val: Initial value of the entries
        rank (int): Number of leading component axes, each of length `self.ncomponents`
        view (bool): Return `z.v` instead of the distributed array `z`

    Returns:
        Array with room for all solution components
    """
    # NOTE(review): without a communicator, `forward_output`, `val`, `rank` and `view` are ignored and the array always
    # has the physical-space shape and dtype from `self.init` -- confirm this is intended for spectral data
    if self.comm is None:
        return self.xp.zeros(self.init[0], dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft if pfft else self.get_pfft()
    if pfft is None:
        # no distributed FFT (1D problem): fall back to the regular data containers
        if forward_output:
            return self.u_init_forward
        else:
            return self.u_init

    global_shape = pfft.global_shape(forward_output)
    p0 = pfft.pencil[forward_output]
    if forward_output is True:
        dtype = pfft.forward.output_array.dtype
    else:
        dtype = pfft.forward.input_array.dtype
    # prepend one axis of length `ncomponents` per rank
    global_shape = (self.ncomponents,) * rank + global_shape

    # GPU backends need the cupy flavor of the distributed array
    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
    return z.v if view else z

1839 

def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Find the axes in which the local array `u` is aligned, i.e. holds the full global extent.

    The local size of `u` in each axis is compared with the global size of a distributed array from a `PFFT` object.
    With padding, several combinations of padded and unpadded axes are tried, in a fixed order, until at least one
    axis matches.

    Args:
        u: Local data, with the components in the first axis
        forward_output (bool): Compare against the spectral (True) or physical (False) global shape
        padding (sequence): Padding factors per axis, or None
        **kwargs: Forwarded to `get_pfft`

    Returns:
        list of int: Aligned axes, not counting the component axis; `[0]` without a communicator

    Raises:
        NotImplementedError: For padding in dimensions other than 2 and 3
        AssertionError: If no aligned axis is found
    """
    if self.comm is None:
        return [0]

    def _alignment(pfft):
        # axes where the local size of `u` equals the global size of the distributed array
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes

1879 

def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Copy the local data `u` into a distributed array that is aligned in `axis`.

    The alignment of `u` is inferred from its local shape with `infer_alignment`. A distributed array from the `PFFT`
    object matching `kwargs` is redistributed to each candidate alignment until its local shape matches `u`. It is then
    filled with `u` and finally redistributed to `axis`.

    Args:
        u: Local data
        axis (int): Axis the result should be aligned in
        forward_output (bool): Whether the result has the spectral (True) or physical (False) layout
        **kwargs: Forwarded to `get_pfft` and `infer_alignment`, e.g. `padding`

    Returns:
        distributed array aligned in `axis`, or `u` itself if there is no communicator

    Raises:
        Exception: If no candidate alignment reproduces the local shape of `u`
    """
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # NOTE(review): the alignment of `u` is always inferred against the physical layout (`forward_output=False`)
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )

1904 

def transform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Forward transform of `u`.

    When `get_pfft` returns no parallel FFT object, the 1D bases transform `u` one axis at a time. Otherwise, `u` is
    written into an input array of the parallel FFT and transformed component by component. If the local shape of
    `u` does not match that array, `u` is redistributed first.

    Args:
        u: Data to transform. The leading array dimension enumerates the components.
        *args: Forwarded to `get_pfft`
        axes (tuple): Axes to transform. On the per-axis path, a falsy value selects all axes.
        padding (tuple): Padding factors forwarded to `get_pfft`. On the parallel path, the output is divided by
            their product.
        **kwargs: Forwarded to `get_pfft`

    Returns:
        Transformed data
    """
    fft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if fft is None:
        if not axes:
            axes = tuple(range(self.ndim))
        result = u.copy()
        for ax in axes:
            # array dimension 0 holds the components, so non-negative axes are shifted by one
            array_axis = ax + 1 if ax >= 0 else ax
            result = self.axes[ax].transform(result, axes=(array_axis,))
        return result

    fft_in = self.newDistArray(fft, forward_output=False, rank=1)
    fft_out = self.newDistArray(fft, forward_output=True, rank=1)

    fft_in[...] = (
        u
        if fft_in.shape == u.shape
        else self.redistribute(u, axis=fft_in.alignment, forward_output=False, padding=padding, **kwargs)
    )

    for component in range(self.ncomponents):
        fft.forward(fft_in[component], fft_out[component], normalize=False)

    if padding is not None:
        fft_out /= np.prod(padding)
    return fft_out

1930 

def itransform(self, u, *args, axes=None, padding=None, **kwargs):
    """
    Backward transform of `u`.

    When `get_pfft` returns no parallel FFT object, the 1D bases transform `u` back one axis at a time. Otherwise,
    `u` is written into a forward-output array of the parallel FFT and transformed back component by component. If
    the local shape of `u` does not match that array, `u` is redistributed first.

    Args:
        u: Data to transform back. The leading array dimension enumerates the components.
        *args: Forwarded to `get_pfft`
        axes (tuple): Axes to transform. On the per-axis path, a falsy value selects all axes.
        padding (tuple): Padding factors forwarded to `get_pfft`. Each `N * padding[i]` must be an integer. On the
            parallel path, the output is multiplied by the product of the factors.
        **kwargs: Forwarded to `get_pfft`

    Returns:
        Backward-transformed data
    """
    if padding is not None:
        assert all(
            (self.axes[dim].N * padding[dim]) % 1 == 0 for dim in range(self.ndim)
        ), 'Cannot do this padding with this resolution. Resulting resolution must be integer'

    fft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs)

    if fft is None:
        if not axes:
            axes = tuple(range(self.ndim))
        result = u.copy()
        for ax in axes:
            # array dimension 0 holds the components, so non-negative axes are shifted by one
            array_axis = ax + 1 if ax >= 0 else ax
            result = self.axes[ax].itransform(result, axes=(array_axis,))
        return result

    fft_in = self.newDistArray(fft, forward_output=True, rank=1)
    fft_out = self.newDistArray(fft, forward_output=False, rank=1)

    fft_in[...] = (
        u
        if fft_in.shape == u.shape
        else self.redistribute(u, axis=fft_in.alignment, forward_output=True, padding=padding, **kwargs)
    )

    for component in range(self.ncomponents):
        fft.backward(fft_in[component], fft_out[component], normalize=True)

    if padding is not None:
        fft_out *= np.prod(padding)
    return fft_out

1960 

def get_local_slice_of_1D_matrix(self, M, axis):
    """
    Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes,
    which you can sort out via the `SpectralHelper.local_slice()` method. When constructing a 1D matrix, you can
    use this method to get the part corresponding to the modes carried by this rank.

    Args:
        M (sparse matrix): Global 1D matrix you want to get the local version of
        axis (int): Direction in which you want the local version. You will get the global matrix in other directions.

    Returns:
        sparse local matrix
    """
    # Compute the slice once and use it for both rows and columns.
    # NOTE(review): the `True` argument presumably selects the forward-output layout; confirm against local_slice.
    local = self.local_slice(True)[axis]
    return M.tocsc()[local, local]

1975 

def expand_matrix_ND(self, matrix, aligned):
    """
    Expand a 1D matrix acting along axis `aligned` to all dimensions of the problem.

    In more than one dimension, the result is a Kronecker product in which the factor for `aligned` is `matrix` and
    the factors for all other axes are identities. Every factor is restricted to the locally carried modes via
    `get_local_slice_of_1D_matrix`, and the factor for axis 0 is the outermost one. In 1D, `matrix` is used as is.
    Explicit zeros are removed from the result with `eliminate_zeros`.

    Args:
        matrix (sparse matrix): Global 1D matrix along axis `aligned`
        aligned (int): Axis the 1D matrix acts along

    Returns:
        sparse matrix for the full problem

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = self.sparse_lib
    other_axes = np.delete(np.arange(self.ndim), aligned)
    ndim = len(other_axes) + 1

    if ndim > 3:
        raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')

    if ndim == 1:
        expanded = matrix
    else:
        factors = [None] * ndim
        factors[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
        for other in other_axes:
            identity = sp.eye(self.axes[other].N)
            factors[other] = self.get_local_slice_of_1D_matrix(identity, other)

        # right fold: kron(F0, kron(F1, ... kron(F_{n-2}, F_{n-1})))
        expanded = factors[-1]
        for factor in reversed(factors[:-1]):
            expanded = sp.kron(factor, expanded)

    return self.eliminate_zeros(expanded)

2007 

def get_filter_matrix(self, axis, **kwargs):
    """
    Get bandpass filter along `axis`. See the documentation of `get_filter_matrix` in the 1D bases for which kwargs
    are admissible. In all other directions, the identity matrix of the respective base is used.

    Args:
        axis (int): Axis along which to filter
        **kwargs: Forwarded to `get_filter_matrix` of the 1D base along `axis`

    Returns:
        sparse bandpass matrix
    """
    if self.ndim == 1:
        return self.axes[0].get_filter_matrix(**kwargs)

    mats = [base.get_Id() for base in self.axes]
    mats[axis] = self.axes[axis].get_filter_matrix(**kwargs)

    # `kron(A, B, format=None)` combines exactly two factors. Unpacking three matrices would pass the third one as
    # `format`, so fold from the right instead: kron(M0, kron(M1, ...)). In 2D this is identical to kron(M0, M1).
    filter_matrix = mats[-1]
    for mat in reversed(mats[:-1]):
        filter_matrix = self.sparse_lib.kron(mat, filter_matrix)
    return filter_matrix

2022 

def get_differentiation_matrix(self, axes, **kwargs):
    """
    Build the differentiation matrix along the given axes.

    The 1D differentiation matrix of every listed axis is expanded with `expand_matrix_ND`. The expanded matrices are
    multiplied in the order the axes are given.

    Args:
        axes (tuple): Axes along which to differentiate.
        **kwargs: Forwarded to `get_differentiation_matrix` of the 1D bases

    Returns:
        sparse differentiation matrix
    """

    def _expanded(ax):
        # N-dimensional version of the 1D differentiation matrix along `ax`
        return self.expand_matrix_ND(self.axes[ax].get_differentiation_matrix(**kwargs), ax)

    D = _expanded(axes[0])
    for ax in axes[1:]:
        D = D @ _expanded(ax)

    self.logger.debug(f'Set up differentiation matrix along axes {axes} with kwargs {kwargs}')
    return D

2040 

def get_integration_matrix(self, axes):
    """
    Build the matrix that integrates along the given axes.

    The 1D integration matrix of every listed axis is expanded with `expand_matrix_ND`. The expanded matrices are
    multiplied in the order the axes are given.

    Args:
        axes (tuple): Axes along which to integrate over.

    Returns:
        sparse integration matrix
    """

    def _expanded(ax):
        # N-dimensional version of the 1D integration matrix along `ax`
        return self.expand_matrix_ND(self.axes[ax].get_integration_matrix(), ax)

    S = _expanded(axes[0])
    for ax in axes[1:]:
        S = S @ _expanded(ax)
    return S

2057 

def get_Id(self):
    """
    Build the identity matrix of the full problem.

    The identity of each 1D base is expanded with `expand_matrix_ND`, and the expanded matrices are multiplied over
    all axes `0 .. ndim-1`.

    Returns:
        sparse identity matrix
    """

    def _expanded_identity(ax):
        # N-dimensional version of the 1D identity along `ax`
        return self.expand_matrix_ND(self.axes[ax].get_Id(), ax)

    Id = _expanded_identity(0)
    for ax in range(1, self.ndim):
        Id = Id @ _expanded_identity(ax)
    return Id

2070 

def get_Dirichlet_recombination_matrix(self, axis=-1):
    """
    Get the Dirichlet recombination matrix along `axis`, expanded to all dimensions with `expand_matrix_ND`.
    Note that this only makes sense in directions discretized with variations of Chebychev bases.

    Args:
        axis (int): Axis you discretized with Chebychev

    Returns:
        sparse matrix
    """
    return self.expand_matrix_ND(self.axes[axis].get_Dirichlet_recombination_matrix(), axis)

2083 

def get_basis_change_matrix(self, axes=None, **kwargs):
    """
    Build a matrix that changes the basis along the given axes.

    Some spectral bases change basis while differentiating, and this matrix changes to whichever basis you request.
    Refer to the methods of the same name in the 1D bases for the parameters to pass as `kwargs`.

    Args:
        axes (tuple): Axes along which to change basis. Defaults to all axes, given as negative indices from the last
            axis to the first.
        **kwargs: Forwarded to `get_basis_change_matrix` of the 1D bases

    Returns:
        sparse basis change matrix
    """
    if axes is None:
        axes = tuple(-(i + 1) for i in range(self.ndim))

    def _expanded(ax):
        # N-dimensional version of the 1D basis change matrix along `ax`
        return self.expand_matrix_ND(self.axes[ax].get_basis_change_matrix(**kwargs), ax)

    C = _expanded(axes[0])
    for ax in axes[1:]:
        C = C @ _expanded(ax)

    self.logger.debug(f'Set up basis change matrix along axes {axes} with kwargs {kwargs}')
    return C