Coverage for pySDC / helpers / spectral_helper.py: 83%

864 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 07:06 +0000

1import numpy as np 

2import scipy 

3from pySDC.implementations.datatype_classes.mesh import mesh 

4from scipy.special import factorial 

5from functools import partial, wraps 

6import logging 

7 

8 

def cache(func):
    """
    Decorator for caching return values of functions.
    This is very similar to `functools.cache`, but without the memory leaks (see
    https://docs.astral.sh/ruff/rules/cached-instance-method/).

    Example:

    .. code-block:: python

        num_calls = 0

        @cache
        def increment(x):
            num_calls += 1
            return x + 1

        increment(0)  # returns 1, num_calls = 1
        increment(1)  # returns 2, num_calls = 2
        increment(0)  # returns 1, num_calls = 2

    Args:
        func (function): The function you want to cache the return value of

    Returns:
        return value of func
    """
    cache_attr = f"_{func.__name__}_cache"

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # the cache lives on the instance, so it dies together with the instance
        try:
            storage = getattr(self, cache_attr)
        except AttributeError:
            storage = {}
            setattr(self, cache_attr, storage)

        key = (args, frozenset(kwargs.items()))
        if key not in storage:
            storage[key] = func(self, *args, **kwargs)
        return storage[key]

    return wrapper

54 

55 

class vkFFT(object):
    """
    pyVkFFT FFT backend.
    The special feature of vkFFT is fast DCT on GPU with cached plans.
    """

    # plans are cached class-wide, keyed by transform parameters
    cached_plans = {}

    @staticmethod
    def is_complex(x):
        """Return whether the dtype of `x` is complex."""
        return 'complex' in str(x.dtype)

    @staticmethod
    def get_plan(transform_type, shape, dtype, axes, norm):
        """
        Retrieve a cached VkFFT plan, creating and caching it on first use.

        Args:
            transform_type (str): 'fft' or 'dct'
            shape (tuple): Shape of the data
            dtype: dtype of the data
            axes (tuple): Axes to transform along
            norm (str): Normalization; only 'backward' is supported by the plans

        Returns:
            pyvkfft.cuda.VkFFTApp: Cached plan
        """
        from pyvkfft.cuda import VkFFTApp

        assert norm == 'backward'

        key = f'{transform_type=}, {shape=}, {dtype=}, {axes=}, {norm=}'

        if key not in vkFFT.cached_plans.keys():

            kwargs = {}

            if transform_type == 'dct':
                kwargs['dct'] = 2  # DCT-II, matching scipy's default `type=2`

            vkFFT.cached_plans[key] = VkFFTApp(shape, dtype, len(axes), axes=axes, norm=norm, **kwargs)

            logger = logging.getLogger(name='VkFFT')
            logger.debug(f'Cached plan for VkFFT: {key}')
        return vkFFT.cached_plans[key]

    @staticmethod
    def fftn(x, s=None, axes=None, norm='backward', overwrite_x=False):
        """Forward FFT along `axes`. The input is never modified."""
        assert not overwrite_x  # for consistent interface with scipy
        assert norm == 'backward'  # for consistent interface with scipy
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=x.shape,
            dtype=x.dtype,
            axes=axes,
            norm=norm,
        )
        _x = x.copy() + 0j  # cast to complex; VkFFT transforms in place
        plan.fft(_x)
        return _x

    @staticmethod
    def ifftn(x, s=None, axes=None, norm='forward', overwrite_x=False):
        """Inverse FFT along `axes` with 'forward' (i.e. no) normalization."""
        assert norm == 'forward'
        assert not overwrite_x  # for consistent interface with scipy

        norm = 'backward'
        plan = vkFFT.get_plan(
            transform_type='fft',
            shape=x.shape,
            dtype=x.dtype,
            axes=axes,
            norm=norm,
        )
        _x = x.copy() + 0j  # promote to complex; VkFFT transforms in place
        plan.ifft(_x)
        # The backward-normalized inverse divides by the product of the transformed
        # lengths; multiply it back to emulate 'forward' normalization. Using the
        # product (not the sum) matches `FFTHelper.itransform`, which divides by
        # np.prod of the transformed lengths.
        scale = 1
        for i in axes:
            scale *= x.shape[i]
        return _x * scale

    @staticmethod
    def dctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """DCT-II along `axes`. The input is never modified."""
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy

        is_complex = vkFFT.is_complex(x)

        # VkFFT DCT plans operate on real data; complex input is handled below
        dtype = x.dtype if not is_complex else x.real.dtype

        plan = vkFFT.get_plan(
            transform_type='dct',
            shape=x.shape,
            dtype=dtype,
            axes=axes,
            norm=norm,
        )

        if is_complex:
            # DCT is linear: transform real and imaginary parts separately
            x_real = x.real.copy()
            x_imag = x.imag.copy()

            plan.fft(x_real)
            plan.fft(x_imag)

            return x_real + 1j * x_imag
        else:
            # transform a copy so the caller's array stays untouched (overwrite_x=False)
            _x = x.copy()
            plan.fft(_x)
            return _x

    @staticmethod
    def idctn(x, type=2, s=None, axes=None, norm=None, overwrite_x=False):
        """Inverse DCT-II along `axes`. The input is never modified."""
        assert type == 2  # for consistent interface with scipy
        assert not overwrite_x  # for consistent interface with scipy

        is_complex = vkFFT.is_complex(x)
        dtype = x.dtype if not is_complex else x.real.dtype

        plan = vkFFT.get_plan(
            transform_type='dct',
            shape=x.shape,
            dtype=dtype,
            axes=axes,
            norm=norm,
        )

        if is_complex:
            # inverse DCT is linear: transform real and imaginary parts separately
            x_real = x.real.copy()
            x_imag = x.imag.copy()

            plan.ifft(x_real)
            plan.ifft(x_imag)

            return x_real + 1j * x_imag
        else:
            # transform a copy so the caller's array stays untouched (overwrite_x=False)
            _x = x.copy()
            plan.ifft(_x)
            return _x

179 

180 

class SpectralHelper1D:
    """
    Abstract base class for 1D spectral discretizations. Defines a common interface with parameters and functions that
    all bases need to have.

    When implementing new bases, please take care to use the modules that are supplied as class attributes to enable
    the code for GPUs.

    Attributes:
        N (int): Resolution
        x0 (float): Coordinate of left boundary
        x1 (float): Coordinate of right boundary
        L (float): Length of the domain
        useGPU (bool): Whether to use GPUs

    """

    fft_lib = scipy.fft
    sparse_lib = scipy.sparse
    linalg = scipy.sparse.linalg
    xp = np
    distributable = False

    def __init__(self, N, x0=None, x1=None, useGPU=False, useFFTW=False):
        """
        Constructor

        Args:
            N (int): Resolution
            x0 (float): Coordinate of left boundary
            x1 (float): Coordinate of right boundary
            useGPU (bool): Whether to use GPUs
            useFFTW (bool): Whether to use FFTW for the transforms

        Raises:
            ValueError: If both GPU and FFTW backends are requested
        """
        # validate before any backend setup happens: setting up the GPU backend
        # imports CuPy and mutates class-level state, which should not occur for
        # an invalid configuration
        if useGPU and useFFTW:
            raise ValueError('Please run either on GPUs or with FFTW, not both!')

        self.N = N
        self.x0 = x0
        self.x1 = x1
        self.L = x1 - x0
        self.useGPU = useGPU
        self.plans = {}
        self.logger = logging.getLogger(name=type(self).__name__)

        if useGPU:
            self.setup_GPU()
            self.logger.debug('Set up for GPU')
        else:
            self.setup_CPU(useFFTW=useFFTW)

    @classmethod
    def setup_GPU(cls):
        """switch to GPU modules"""
        import cupy as cp
        import cupyx.scipy.sparse as sparse_lib
        import cupyx.scipy.sparse.linalg as linalg
        from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

        cls.xp = cp
        cls.sparse_lib = sparse_lib
        cls.linalg = linalg
        cls.fft_lib = vkFFT

    @classmethod
    def setup_CPU(cls, useFFTW=False):
        """switch to CPU modules"""

        cls.xp = np
        cls.sparse_lib = scipy.sparse
        cls.linalg = scipy.sparse.linalg

        if useFFTW:
            from mpi4py_fft import fftw

            cls.fft_backend = 'fftw'
            cls.fft_lib = fftw
        else:
            cls.fft_backend = 'scipy'
            cls.fft_lib = scipy.fft

        cls.fft_comm_backend = 'MPI'
        cls.dtype = mesh

    def get_Id(self):
        """
        Get identity matrix

        Returns:
            sparse diagonal identity matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_zero(self):
        """
        Get a matrix with all zeros of the correct size.

        Returns:
            sparse matrix with zeros everywhere
        """
        return 0 * self.get_Id()

    def get_differentiation_matrix(self):
        """Get differentiation matrix. Implemented by the individual bases."""
        raise NotImplementedError()

    def get_integration_matrix(self):
        """Get integration matrix. Implemented by the individual bases."""
        raise NotImplementedError()

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        raise NotImplementedError()

    def get_wavenumbers(self):
        """
        Get the grid in spectral space
        """
        raise NotImplementedError

    def get_empty_operator_matrix(self, S, O):
        """
        Return a matrix of operators to be filled with the connections between the solution components.

        Args:
            S (int): Number of components in the solution
            O (sparse matrix): Zero matrix used for initialization

        Returns:
            list of lists containing sparse zeros
        """
        return [[O for _ in range(S)] for _ in range(S)]

    def get_basis_change_matrix(self, *args, **kwargs):
        """
        Some spectral discretization change the basis during differentiation. This method can be used to transfer
        between the various bases.

        This method accepts arbitrary arguments that may not be used in order to provide an easy interface for multi-
        dimensional bases. For instance, you may combine an FFT discretization with an ultraspherical discretization.
        The FFT discretization will always be in the same base, but the ultraspherical discretization uses a different
        base for every derivative. You can then ask all bases for transfer matrices from one ultraspherical derivative
        base to the next. The FFT discretization will ignore this and return an identity while the ultraspherical
        discretization will return the desired matrix. After a Kronecker product, you get the 2D version of the matrix
        you want. This is what the `SpectralHelper` does when you call the method of the same name on it.

        Returns:
            sparse bases change matrix
        """
        return self.sparse_lib.eye(self.N)

    def get_BC(self, kind):
        """
        To facilitate boundary conditions (BCs) we use either a basis where all functions satisfy the BCs automatically,
        as is the case in FFT basis for periodic BCs, or boundary bordering. In boundary bordering, specific lines in
        the matrix are replaced by the boundary conditions as obtained by this method.

        Args:
            kind (str): The type of BC you want to implement please refer to the implementations of this method in the
                individual 1D bases for what is implemented

        Returns:
            self.xp.array: Boundary condition
        """
        raise NotImplementedError(f'No boundary conditions of {kind=!r} implemented!')

    def get_filter_matrix(self, kmin=0, kmax=None):
        """
        Get a bandpass filter.

        Args:
            kmin (int): Lower limit of the bandpass filter
            kmax (int): Upper limit of the bandpass filter

        Returns:
            sparse matrix
        """

        k = abs(self.get_wavenumbers())

        kmax = max(k) if kmax is None else kmax

        mask = self.xp.logical_or(k >= kmax, k < kmin)

        # the lil-format column assignment below is only implemented on CPU,
        # so move the identity to the host first when running on GPU
        if self.useGPU:
            Id = self.get_Id().get()
        else:
            Id = self.get_Id()
        F = Id.tolil()
        F[:, mask] = 0
        return F.tocsc()

    def get_1dgrid(self):
        """
        Get the grid in physical space

        Returns:
            self.xp.array: Grid
        """
        raise NotImplementedError

379 

380 

class ChebychevHelper(SpectralHelper1D):
    """
    The Chebychev base consists of special kinds of polynomials, with the main advantage that you can easily transform
    between physical and spectral space by discrete cosine transform.
    The differentiation in the Chebychev T base is dense, but can be preconditioned to yield a differentiation operator
    that moves to Chebychev U basis during differentiation, which is sparse. When using this technique, problems need to
    be formulated in first order formulation.

    This implementation is largely based on the Dedalus paper (https://doi.org/10.1103/PhysRevResearch.2.023068).
    """

    def __init__(self, *args, x0=-1, x1=1, **kwargs):
        """
        Constructor.
        Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you
        may choose to run on GPUs via the `useGPU` argument.

        Args:
            x0 (float): Coordinate of left boundary. Note that only -1 is currently implented
            x1 (float): Coordinate of right boundary. Note that only +1 is currently implented
        """
        # need linear transformation y = ax + b with a = (x1-x0)/2 and b = (x1+x0)/2
        self.lin_trf_fac = (x1 - x0) / 2
        self.lin_trf_off = (x1 + x0) / 2
        super().__init__(*args, x0=x0, x1=x1, **kwargs)

        self.norm = self.get_norm()

    def get_1dgrid(self):
        '''
        Generates a 1D grid with Chebychev points. These are clustered at the boundary. You need this kind of grid to
        use discrete cosine transformation (DCT) to get the Chebychev representation. If you want a different grid, you
        need to do an affine transformation before any Chebychev business.

        Returns:
            numpy.ndarray: 1D grid
        '''
        return self.lin_trf_fac * self.xp.cos(np.pi / self.N * (self.xp.arange(self.N) + 0.5)) + self.lin_trf_off

    def get_wavenumbers(self):
        """Get the domain in spectral space"""
        return self.xp.arange(self.N)

    @cache
    def get_conv(self, name, N=None):
        '''
        Get conversion matrix between different kinds of polynomials. The supported kinds are
         - T: Chebychev polynomials of first kind
         - U: Chebychev polynomials of second kind
         - D: Dirichlet recombination.

        You get the desired matrix by choosing a name as ``A2B``. I.e. ``T2U`` for the conversion matrix from T to U.
        Once generates matrices are cached. So feel free to call the method as often as you like.

        Args:
            name (str): Conversion code, e.g. 'T2U'
            N (int): Size of the matrix (optional)

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = N if N else self.N
        sp = self.sparse_lib

        def get_forward_conv(name):
            # only the "forward" conversions are assembled directly;
            # their inverses are computed below by sparse matrix inversion
            if name == 'T2U':
                mat = (sp.eye(N) - sp.eye(N, k=2)).tocsc() / 2.0
                mat[:, 0] *= 2
            elif name == 'D2T':
                mat = sp.eye(N) - sp.eye(N, k=2)
            elif name[0] == name[-1]:
                mat = self.sparse_lib.eye(self.N)
            else:
                raise NotImplementedError(f'Don\'t have conversion matrix {name!r}')
            return mat

        try:
            mat = get_forward_conv(name)
        except NotImplementedError as E:
            try:
                fwd = get_forward_conv(name[::-1])
                import scipy.sparse as sp

                # sparse inversion is only implemented on CPU, so transfer when on GPU
                if self.sparse_lib == sp:
                    mat = self.sparse_lib.linalg.inv(fwd.tocsc())
                else:
                    mat = self.sparse_lib.csc_matrix(sp.linalg.inv(fwd.tocsc().get()))
            except NotImplementedError:
                raise NotImplementedError from E

        return mat

    def get_basis_change_matrix(self, conv='T2T', **kwargs):
        """
        As the differentiation matrix in Chebychev-T base is dense but is sparse when simultaneously changing base to
        Chebychev-U, you may need a basis change matrix to transfer the other matrices as well. This function returns a
        conversion matrix from `ChebychevHelper.get_conv`. Not that `**kwargs` are used to absorb arguments for other
        bases, see documentation of `SpectralHelper1D.get_basis_change_matrix`.

        Args:
            conv (str): Conversion code, i.e. T2U

        Returns:
            Sparse conversion matrix
        """
        return self.get_conv(conv)

    def get_integration_matrix(self, lbnd=0):
        """
        Get matrix for integration

        Args:
            lbnd (float): Lower bound for integration, only 0 is currently implemented

        Returns:
            Sparse integration matrix
        """
        S = self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1) @ self.get_conv('T2U')
        n = self.xp.arange(self.N)
        if lbnd == 0:
            S = S.tocsc()
            # first row fixes the integration constant such that the lower bound is 0
            S[0, 1::2] = (
                (n / (2 * (self.xp.arange(self.N) + 1)))[1::2]
                * (-1) ** (self.xp.arange(self.N // 2))
                / (np.append([1], self.xp.arange(self.N // 2 - 1) + 1))
            ) * self.lin_trf_fac
        else:
            raise NotImplementedError(f'This function allows to integrate only from x=0, you attempted from x={lbnd}.')
        return S

    def get_integration_weights(self):
        """Weights for integration across entire domain"""
        n = self.xp.arange(self.N, dtype=float)

        # integrals of T_n over [-1, 1]: zero for odd n, 2/(1-n^2) for even n
        weights = (-1) ** n + 1
        weights[2:] /= 1 - (n**2)[2:]

        # rescale from the reference interval [-1, 1] to the actual domain length
        weights /= 2 / self.L
        return weights

    def get_differentiation_matrix(self, p=1):
        '''
        Keep in mind that the T2T differentiation matrix is dense.

        Args:
            p (int): Derivative you want to compute

        Returns:
            numpy.ndarray: Differentiation matrix
        '''
        D = self.xp.zeros((self.N, self.N))
        for j in range(self.N):
            for k in range(j):
                D[k, j] = 2 * j * ((j - k) % 2)

        D[0, :] /= 2
        return self.sparse_lib.csc_matrix(self.xp.linalg.matrix_power(D, p)) / self.lin_trf_fac**p

    @cache
    def get_norm(self, N=None):
        '''
        Get normalization for converting Chebychev coefficients and DCT

        Args:
            N (int, optional): Resolution

        Returns:
            self.xp.array: Normalization
        '''
        N = self.N if N is None else N
        norm = self.xp.ones(N) / N
        norm[0] /= 2
        return norm

    def transform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        DCT along axes. `kwargs` will be passed on to the FFT library.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in spectral space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')

        trf = self.fft_lib.dctn(u, *args, axes=axes, type=2, **kwargs)
        for axis in axes:

            if self.N < trf.shape[axis]:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # removed in the middle rather than the end. We need to resort this here and put the highest frequencies
                # in the middle.
                _trf = self.xp.zeros_like(trf)
                N = self.N
                N_pad = _trf.shape[axis] - N
                end_first_half = N // 2 + 1

                # copy first "half"
                su = [slice(None)] * trf.ndim
                su[axis] = slice(0, end_first_half)
                _trf[tuple(su)] = trf[tuple(su)]

                # copy second "half"
                su = [slice(None)] * u.ndim
                su[axis] = slice(end_first_half + N_pad, None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(end_first_half, N)
                _trf[tuple(su)] = trf[tuple(s_u)]

                trf = _trf

            # apply the DCT normalization along this axis only
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.xp.ones(trf.shape[axis]) * self.norm[-1]
            norm[: self.N] = self.norm
            trf *= norm[(*expansion,)]
        return trf

    def itransform(self, u, *args, axes=None, shape=None, **kwargs):
        """
        Inverse DCT along axis.

        Args:
            u: Data you want to transform
            axes (tuple): Axes you want to transform along

        Returns:
            Data in physical space
        """
        axes = axes if axes else tuple(i for i in range(u.ndim))
        kwargs['s'] = shape
        kwargs['norm'] = kwargs.get('norm', 'backward')
        kwargs['overwrite_x'] = kwargs.get('overwrite_x', False)

        for axis in axes:

            if self.N == u.shape[axis]:
                _u = u.copy()
            else:
                # mpi4py-fft implements padding only for FFT, where the frequencies are sorted such that the zeros are
                # added in the middle rather than the end. We need to resort this here and put the padding in the end.
                N = self.N
                _u = self.xp.zeros_like(u)

                # copy first half
                su = [slice(None)] * u.ndim
                su[axis] = slice(0, N // 2 + 1)
                _u[tuple(su)] = u[tuple(su)]

                # copy second half
                su = [slice(None)] * u.ndim
                su[axis] = slice(-(N // 2), None)
                s_u = [slice(None)] * u.ndim
                s_u[axis] = slice(N // 2, N // 2 + (N // 2))
                _u[tuple(s_u)] = u[tuple(su)]

                if N % 2 == 0:
                    su = [slice(None)] * u.ndim
                    su[axis] = N // 2
                    _u[tuple(su)] *= 2

            # undo the DCT normalization along this axis, rescaled for possible padding
            expansion = [np.newaxis for _ in u.shape]
            expansion[axis] = slice(0, u.shape[axis], 1)
            norm = self.get_norm(u.shape[axis]) * _u.shape[axis] / self.N

            _u /= norm[(*expansion,)]

        return self.fft_lib.idctn(_u, *args, axes=axes, type=2, **kwargs)

    def get_BC(self, kind, **kwargs):
        """
        Get boundary condition row for boundary bordering. `kwargs` will be passed on to implementations of the BC of
        the kind you choose. Specifically, `x` for `'dirichlet'` boundary condition, which is the coordinate at which to
        set the BC.

        Args:
            kind ('integral' or 'dirichlet'): Kind of boundary condition you want
        """
        if kind.lower() == 'integral':
            return self.get_integ_BC_row(**kwargs)
        elif kind.lower() == 'dirichlet':
            return self.get_Dirichlet_BC_row(**kwargs)
        elif kind.lower() == 'neumann':
            return self.get_Neumann_BC_row(**kwargs)
        else:
            return super().get_BC(kind)

    def get_integ_BC_row(self):
        """
        Get a row for generating integral BCs with T polynomials.
        It returns the values of the integrals of T polynomials over the entire interval.

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N) + 1
        me = self.xp.zeros_like(n).astype(float)
        me[2:] = ((-1) ** n[1:-1] + 1) / (1 - n[1:-1] ** 2)
        me[0] = 2.0
        return me

    def get_Dirichlet_BC_row(self, x):
        """
        Get a row for generating Dirichlet BCs at x with T polynomials.
        It returns the values of the T polynomials at x.

        Args:
            x (float): Position of the boundary condition

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        if x == -1:
            return (-1) ** self.xp.arange(self.N)
        elif x == 1:
            return self.xp.ones(self.N)
        elif x == 0:
            n = (1 + (-1) ** self.xp.arange(self.N)) / 2
            n[2::4] *= -1
            return n
        else:
            raise NotImplementedError(f'Don\'t know how to generate Dirichlet BC\'s at {x=}!')

    def get_Neumann_BC_row(self, x):
        """
        Get a row for generating Neumann BCs at x with T polynomials.

        Args:
            x (float): Position of the boundary condition

        Returns:
            self.xp.ndarray: Row to put into a matrix
        """
        n = self.xp.arange(self.N, dtype='D')
        nn = n**2
        if x == -1:
            me = nn
            me[1:] *= (-1) ** n[:-1]
            return me
        elif x == 1:
            return nn
        else:
            raise NotImplementedError(f'Don\'t know how to generate Neumann BC\'s at {x=}!')

    def get_Dirichlet_recombination_matrix(self):
        '''
        Get matrix for Dirichlet recombination, which changes the basis to have sparse boundary conditions.
        This makes for a good right preconditioner.

        Returns:
            scipy.sparse: Sparse conversion matrix
        '''
        N = self.N
        sp = self.sparse_lib

        return sp.eye(N) - sp.eye(N, k=2)

751 

752 

class UltrasphericalHelper(ChebychevHelper):
    """
    This implementation follows https://doi.org/10.1137/120865458.
    The ultraspherical method works in Chebychev polynomials as well, but also uses various Gegenbauer polynomials.
    The idea is that for every derivative of Chebychev T polynomials, there is a basis of Gegenbauer polynomials where the differentiation matrix is a single off-diagonal.
    There are also conversion operators from one derivative basis to the next that are sparse.

    This basis is used like this: For every equation that you have, look for the highest derivative and bump all matrices to the correct basis. If your highest derivative is 2 and you have an identity, it needs to get bumped from 0 to 1 and from 1 to 2. If you have a first derivative as well, it needs to be bumped from 1 to 2.
    You don't need the same resulting basis in all equations. You just need to take care that you translate the right hand side to the correct basis as well.
    """

    def get_differentiation_matrix(self, p=1):
        """
        Notice that while sparse, this matrix is not diagonal, which means the inversion cannot be parallelized easily.

        Args:
            p (int): Order of the derivative

        Returns:
            sparse differentiation matrix
        """
        sp = self.sparse_lib
        xp = self.xp
        N = self.N
        l = p
        return 2 ** (l - 1) * factorial(l - 1) * sp.diags(xp.arange(N - l) + l, offsets=l) / self.lin_trf_fac**p

    def get_S(self, lmbda):
        """
        Get matrix for bumping the derivative base by one from lmbda to lmbda + 1. This is the same language as in
        https://doi.org/10.1137/120865458.

        Args:
            lmbda (int): Ingoing derivative base

        Returns:
            sparse matrix: Conversion from derivative base lmbda to lmbda + 1
        """
        N = self.N

        if lmbda == 0:
            # T -> C^(1): assembled on CPU with scipy because of the lil column scaling
            sp = scipy.sparse
            mat = ((sp.eye(N) - sp.eye(N, k=2)) / 2.0).tolil()
            mat[:, 0] *= 2
        else:
            sp = self.sparse_lib
            xp = self.xp
            mat = sp.diags(lmbda / (lmbda + xp.arange(N))) - sp.diags(
                lmbda / (lmbda + 2 + xp.arange(N - 2)), offsets=+2
            )

        return self.sparse_lib.csc_matrix(mat)

    def get_basis_change_matrix(self, p_in=0, p_out=0, **kwargs):
        """
        Get a conversion matrix from derivative base `p_in` to `p_out`.

        Args:
            p_out (int): Resulting derivative base
            p_in (int): Ingoing derivative base
        """
        mat_fwd = self.sparse_lib.eye(self.N)
        for i in range(min([p_in, p_out]), max([p_in, p_out])):
            mat_fwd = self.get_S(i) @ mat_fwd

        if p_out > p_in:
            return mat_fwd

        else:
            # We have to invert the matrix on CPU because the GPU equivalent is not implemented in CuPy at the time of writing.
            import scipy.sparse as sp

            if self.useGPU:
                mat_fwd = mat_fwd.get()

            mat_bck = sp.linalg.inv(mat_fwd.tocsc())

            return self.sparse_lib.csc_matrix(mat_bck)

    def get_integration_matrix(self):
        """
        Get an integration matrix. Please use `UltrasphericalHelper.get_integration_constant` afterwards to compute the
        integration constant such that integration starts from x=-1.

        Example:

        .. code-block:: python

            import numpy as np
            from pySDC.helpers.spectral_helper import UltrasphericalHelper

            N = 4
            helper = UltrasphericalHelper(N)
            coeffs = np.random.random(N)
            coeffs[-1] = 0

            poly = np.polynomial.Chebyshev(coeffs)

            S = helper.get_integration_matrix()
            U_hat = S @ coeffs
            U_hat[0] = helper.get_integration_constant(U_hat, axis=-1)

            assert np.allclose(poly.integ(lbnd=-1).coef[:-1], U_hat)

        Returns:
            sparse integration matrix
        """
        return (
            self.sparse_lib.diags(1 / (self.xp.arange(self.N - 1) + 1), offsets=-1)
            @ self.get_basis_change_matrix(p_out=1, p_in=0)
            * self.lin_trf_fac
        )

    def get_integration_constant(self, u_hat, axis):
        """
        Get integration constant for lower bound of -1. See documentation of `UltrasphericalHelper.get_integration_matrix` for details.

        Args:
            u_hat: Solution in spectral space
            axis: Axis you want to integrate over

        Returns:
            Integration constant, has one less dimension than `u_hat`
        """
        # Use full slices on the untouched axes: initializing with `None` would
        # insert new axes (numpy treats None as np.newaxis) instead of selecting
        # the existing ones, which breaks for multi-dimensional input.
        slices = [slice(None)] * u_hat.ndim
        slices[axis] = slice(1, u_hat.shape[axis])
        # alternating signs are T_n(-1) = (-1)^n; reshape them onto the
        # integration axis so broadcasting works for any axis, not just the last
        signs = (-1) ** (self.xp.arange(u_hat.shape[axis] - 1))
        shape = [1] * u_hat.ndim
        shape[axis] = -1
        return self.xp.sum(u_hat[(*slices,)] * signs.reshape(shape), axis=axis)

882 

883 

884class FFTHelper(SpectralHelper1D): 

885 distributable = True 

886 

887 def __init__(self, *args, x0=0, x1=2 * np.pi, **kwargs): 

888 """ 

889 Constructor. 

890 Please refer to the parent class for additional arguments. Notably, you have to supply a resolution `N` and you 

891 may choose to run on GPUs via the `useGPU` argument. 

892 

893 Args: 

894 x0 (float, optional): Coordinate of left boundary 

895 x1 (float, optional): Coordinate of right boundary 

896 """ 

897 super().__init__(*args, x0=x0, x1=x1, **kwargs) 

898 

899 def get_1dgrid(self): 

900 """ 

901 We use equally spaced points including the left boundary and not including the right one, which is the left boundary. 

902 """ 

903 dx = self.L / self.N 

904 return self.xp.arange(self.N) * dx + self.x0 

905 

906 def get_wavenumbers(self): 

907 """ 

908 Be careful that this ordering is very unintuitive. 

909 """ 

910 return self.xp.fft.fftfreq(self.N, 1.0 / self.N) * 2 * np.pi / self.L 

911 

912 def get_differentiation_matrix(self, p=1): 

913 """ 

914 This matrix is diagonal, allowing to invert concurrently. 

915 

916 Args: 

917 p (int): Order of the derivative 

918 

919 Returns: 

920 sparse differentiation matrix 

921 """ 

922 k = self.get_wavenumbers() 

923 

924 if self.useGPU: 

925 if p > 1: 

926 # Have to raise the matrix to power p on CPU because the GPU equivalent is not implemented in CuPy at the time of writing. 

927 from scipy.sparse.linalg import matrix_power 

928 

929 D = self.sparse_lib.diags(1j * k).get() 

930 return self.sparse_lib.csc_matrix(matrix_power(D, p)) 

931 else: 

932 return self.sparse_lib.diags(1j * k) 

933 else: 

934 return self.linalg.matrix_power(self.sparse_lib.diags(1j * k), p) 

935 

936 def get_integration_matrix(self, p=1): 

937 """ 

938 Get integration matrix to compute `p`-th integral over the entire domain. 

939 

940 Args: 

941 p (int): Order of integral you want to compute 

942 

943 Returns: 

944 sparse integration matrix 

945 """ 

946 k = self.xp.array(self.get_wavenumbers(), dtype='complex128') 

947 k[0] = 1j * self.L 

948 return self.linalg.matrix_power(self.sparse_lib.diags(1 / (1j * k)), p) 

949 

950 def get_integration_weights(self): 

951 """Weights for integration across entire domain""" 

952 weights = self.xp.zeros(self.N) 

953 weights[0] = self.L / self.N 

954 return weights 

955 

956 def get_plan(self, u, forward, *args, **kwargs): 

957 if self.fft_lib.__name__ == 'mpi4py_fft.fftw': 

958 if 'axes' in kwargs.keys(): 

959 kwargs['axes'] = tuple(kwargs['axes']) 

960 key = (forward, u.shape, args, *(me for me in kwargs.values())) 

961 if key in self.plans.keys(): 

962 return self.plans[key] 

963 else: 

964 self.logger.debug(f'Generating FFT plan for {key=}') 

965 transform = self.fft_lib.fftn(u, *args, **kwargs) if forward else self.fft_lib.ifftn(u, *args, **kwargs) 

966 self.plans[key] = transform 

967 

968 return self.plans[key] 

969 else: 

970 if forward: 

971 return partial(self.fft_lib.fftn, norm=kwargs.get('norm', 'backward')) 

972 else: 

973 return partial(self.fft_lib.ifftn, norm=kwargs.get('norm', 'forward')) 

974 

975 def transform(self, u, *args, axes=None, shape=None, **kwargs): 

976 """ 

977 FFT along axes. `kwargs` are passed on to the FFT library. 

978 

979 Args: 

980 u: Data you want to transform 

981 axes (tuple): Axes you want to transform over 

982 

983 Returns: 

984 transformed data 

985 """ 

986 axes = axes if axes else tuple(i for i in range(u.ndim)) 

987 kwargs['s'] = shape 

988 plan = self.get_plan(u, *args, forward=True, axes=axes, **kwargs) 

989 return plan(u, *args, axes=axes, **kwargs) 

990 

def itransform(self, u, *args, axes=None, shape=None, **kwargs):
    """
    Inverse FFT.

    Args:
        u: Data you want to transform
        axes (tuple): Axes over which to transform; defaults to all axes of `u`
        shape (tuple): Desired output shape along the transformed axes (forwarded as `s`)

    Returns:
        transformed data
    """
    axes = axes if axes else tuple(i for i in range(u.ndim))
    kwargs['s'] = shape
    plan = self.get_plan(u, *args, forward=False, axes=axes, **kwargs)
    # normalize manually by the number of transformed modes; the plans returned
    # by `get_plan` apply no scaling of their own (see the `norm` defaults there)
    return plan(u, *args, axes=axes, **kwargs) / np.prod([u.shape[axis] for axis in axes])

1006 

def get_BC(self, kind):
    """
    Get a sort of boundary condition. You can use `kind=integral`, to fix the integral, or you can use `kind=Nyquist`.
    The latter is not really a boundary condition, but is used to set the Nyquist mode to some value, preferably zero.
    You should set the Nyquist mode zero when the solution in physical space is real and the resolution is even.

    Args:
        kind ('integral' or 'nyquist'): Kind of BC

    Returns:
        self.xp.ndarray: Boundary condition row

    Raises:
        AssertionError: When requesting the Nyquist mode with odd resolution
    """
    if kind.lower() == 'integral':
        return self.get_integ_BC_row()
    elif kind.lower() == 'nyquist':
        # with odd N there is no unpaired mode, so eliminating one would destroy information
        assert (
            self.N % 2 == 0
        ), f'Do not eliminate the Nyquist mode with odd resolution as it is fully resolved. You chose {self.N} in this axis'
        BC = self.xp.zeros(self.N)
        BC[self.get_Nyquist_mode_index()] = 1
        return BC
    else:
        # unknown kinds are delegated to the parent class
        return super().get_BC(kind)

1030 

def get_Nyquist_mode_index(self):
    """
    Find the index of the Nyquist mode, i.e. the mode with the most negative wavenumber, which
    has no positive counterpart for even resolution. Real waves of this wavenumber cannot be
    properly resolved, so set this mode to zero when representing real functions on
    even-resolution grids.

    Returns:
        int: Index of the Nyquist mode
    """
    wavenumbers = self.get_wavenumbers()
    most_negative = min(wavenumbers)
    (matches,) = self.xp.where(wavenumbers == most_negative)
    return matches[0]

1044 

def get_integ_BC_row(self):
    """
    Row enforcing an integral condition: with a periodic FFT basis only the 0-mode
    has a non-zero integral over the domain.
    """
    row = self.xp.zeros(self.N)
    row[0] = self.L / self.N
    return row

1052 

1053 

1054class SpectralHelper: 

1055 """ 

1056 This class has three functions: 

1057 - Easily assemble matrices containing multiple equations 

1058 - Direct product of 1D bases to solve problems in more dimensions 

1059 - Distribute the FFTs to facilitate concurrency. 

1060 

1061 Attributes: 

1062 comm (mpi4py.Intracomm): MPI communicator 

1063 debug (bool): Perform additional checks at extra computational cost 

1064 useGPU (bool): Whether to use GPUs 

1065 axes (list): List of 1D bases 

1066 components (list): List of strings of the names of components in the equations 

1067 full_BCs (list): List of Dictionaries containing all information about the boundary conditions 

1068 BC_mat (list): List of lists of sparse matrices to put BCs into and eventually assemble the BC matrix from 

1069 BCs (sparse matrix): Matrix containing only the BCs 

1070 fft_cache (dict): Cache FFTs of various shapes here to facilitate padding and so on 

1071 BC_rhs_mask (self.xp.ndarray): Mask values that contain boundary conditions in the right hand side 

1072 BC_zero_index (self.xp.ndarray): Indeces of rows in the matrix that are replaced by BCs 

1073 BC_line_zero_matrix (sparse matrix): Matrix that zeros rows where we can then add the BCs in using `BCs` 

1074 rhs_BCs_hat (self.xp.ndarray): Boundary conditions in spectral space 

1075 global_shape (tuple): Global shape of the solution as in `mpi4py-fft` 

1076 fft_obj: When using distributed FFTs, this will be a parallel transform object from `mpi4py-fft` 

1077 init (tuple): This is the same `init` that is used throughout the problem classes 

1078 init_forward (tuple): This is the equivalent of `init` in spectral space 

1079 """ 

1080 

1081 xp = np 

1082 fft_lib = scipy.fft 

1083 sparse_lib = scipy.sparse 

1084 linalg = scipy.sparse.linalg 

1085 dtype = mesh 

1086 fft_backend = 'scipy' 

1087 fft_comm_backend = 'MPI' 

1088 

@classmethod
def setup_GPU(cls):
    """
    Switch the numerical backends of the class to GPU modules.

    Replaces the array, sparse, linear algebra and FFT libraries with their
    cupy/cupyx counterparts and selects the NCCL communication backend.
    Requires a working CUDA + cupy installation.
    """
    import cupy as cp
    import cupyx.scipy.sparse as sparse_lib
    import cupyx.scipy.sparse.linalg as linalg
    import cupyx.scipy.fft as fft_lib
    from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh

    cls.xp = cp
    cls.sparse_lib = sparse_lib
    cls.linalg = linalg

    cls.fft_lib = fft_lib
    cls.fft_backend = 'cupyx-scipy'
    cls.fft_comm_backend = 'NCCL'  # GPU-aware communication backend for mpi4py-fft

    cls.dtype = cupy_mesh

1107 

@classmethod
def setup_CPU(cls, useFFTW=False):
    """
    Switch the numerical backends of the class to CPU modules.

    Args:
        useFFTW (bool): Use the mpi4py-fft FFTW bindings instead of scipy's FFT
    """
    cls.xp = np
    cls.sparse_lib = scipy.sparse
    cls.linalg = scipy.sparse.linalg

    if useFFTW:
        from mpi4py_fft import fftw

        cls.fft_lib = fftw
        cls.fft_backend = 'fftw'
    else:
        cls.fft_lib = scipy.fft
        cls.fft_backend = 'scipy'

    cls.fft_comm_backend = 'MPI'
    cls.dtype = mesh

1127 

def __init__(self, comm=None, useGPU=False, debug=False):
    """
    Constructor

    Args:
        comm (mpi4py.Intracomm): MPI communicator
        useGPU (bool): Whether to use GPUs
        debug (bool): Perform additional checks at extra computational cost
    """
    self.comm = comm
    self.debug = debug
    self.useGPU = useGPU

    # switch the class-level backends before any axis is constructed
    if useGPU:
        self.setup_GPU()
    else:
        self.setup_CPU()

    self.axes = []        # 1D bases, one per spatial dimension
    self.components = []  # names of the solution components

    self.full_BCs = []    # full description of every registered boundary condition
    self.BC_mat = None    # list-of-lists of sparse BC matrices, filled by add_BC/remove_BC
    self.BCs = None       # assembled sparse BC operator, set in setup_BCs

    self.fft_cache = {}   # caches FFT objects of various shapes (see get_fft)

    self.logger = logging.getLogger(name='Spectral Discretization')
    if debug:
        self.logger.setLevel(logging.DEBUG)

1158 

@property
def u_init(self):
    """Empty data container in physical space."""
    container_type = self.dtype
    return container_type(self.init)

1165 

@property
def u_init_forward(self):
    """Empty data container in spectral space."""
    container_type = self.dtype
    return container_type(self.init_forward)

1172 

@property
def u_init_physical(self):
    """Empty data container in physical space."""
    container_type = self.dtype
    return container_type(self.init_physical)

1179 

@property
def shape(self):
    """Shape of an individual solution component (the leading component axis is dropped)."""
    local_shape = self.init[0]
    return local_shape[1:]

1186 

@property
def ndim(self):
    """Number of spatial dimensions, i.e. the number of registered 1D bases."""
    return len(self.axes)

1190 

@property
def ncomponents(self):
    """Number of solution components registered in this problem."""
    return len(self.components)

1194 

@property
def V(self):
    """Domain volume, i.e. the product of the lengths of all axes."""
    lengths = [axis.L for axis in self.axes]
    return np.prod(lengths)

1201 

def add_axis(self, base, *args, **kwargs):
    """
    Add an axis to the domain by deciding on a suitable 1D base.
    Arguments to the bases are forwarded using `*args` and `**kwargs`. Please refer to the
    documentation of the 1D bases for possible arguments.

    Args:
        base (str): 1D spectral method

    Raises:
        NotImplementedError: If `base` does not name a known 1D method
    """
    kwargs['useGPU'] = self.useGPU

    name = base.lower()
    if name in ('chebychov', 'chebychev', 'cheby', 'chebychovhelper'):
        new_axis = ChebychevHelper(*args, **kwargs)
    elif name in ('fft', 'fourier', 'ffthelper'):
        new_axis = FFTHelper(*args, **kwargs)
    elif name in ('ultraspherical', 'gegenbauer'):
        new_axis = UltrasphericalHelper(*args, **kwargs)
    else:
        raise NotImplementedError(f'{base=!r} is not implemented!')

    # make sure the base uses the same numerical backends as this helper
    new_axis.xp = self.xp
    new_axis.sparse_lib = self.sparse_lib
    self.axes.append(new_axis)

1223 

def add_component(self, name):
    """
    Add solution component(s).

    Args:
        name (str or list of strings): Name(s) of component(s)

    Raises:
        Exception: If a component of the same name was already added
        NotImplementedError: If `name` is neither a string nor a list/tuple of strings
    """
    # use isinstance rather than `type(...) in [...]` (idiomatic; also accepts subclasses)
    if isinstance(name, (list, tuple)):
        for me in name:
            self.add_component(me)
    elif isinstance(name, str):
        if name in self.components:
            raise Exception(f'{name=!r} is already added to this problem!')
        self.components.append(name)
    else:
        raise NotImplementedError

1240 

def index(self, name):
    """
    Get the index of component `name`.

    Args:
        name (str or list of strings): Name(s) of component(s)

    Returns:
        int: Index of the component, or a generator of indices for a list/tuple of names

    Raises:
        NotImplementedError: If `name` has an unsupported type
    """
    # use isinstance rather than `type(...) in [...]` (idiomatic; also accepts subclasses)
    if isinstance(name, (str, int)):
        return self.components.index(name)
    elif isinstance(name, (list, tuple)):
        # a generator is returned on purpose; callers unpack it like a tuple
        return (self.index(me) for me in name)
    else:
        raise NotImplementedError(f'Don\'t know how to compute index for {type(name)=}')

1257 

def get_empty_operator_matrix(self, diag=False):
    """
    Return a matrix of operators to be filled with the connections between the solution components.

    Every slot holds an independent all-zero sparse matrix. (The previous implementation reused
    a single zero-matrix object for all slots, which risks aliasing bugs if any entry is ever
    modified in place.)

    Args:
        diag (bool): Whether operator is block-diagonal

    Returns:
        list (or list of lists) containing sparse zero matrices
    """
    S = len(self.components)
    if diag:
        return [self.get_Id() * 0 for _ in range(S)]
    else:
        return [[self.get_Id() * 0 for _ in range(S)] for _ in range(S)]

1274 

def get_BC(self, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Use this method for boundary bordering. It gets the respective matrix row and embeds it into a matrix.
    Pay attention that if you have multiple BCs in a single equation, you need to put them in different lines.
    Typically, the last line that does not contain a BC is the best choice.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction

    Returns:
        sparse matrix containing the BC

    Raises:
        NotImplementedError: For more than three dimensions
    """
    sp = scipy.sparse

    base = self.axes[axis]

    # start from an all-zero LIL matrix (cheap row assignment) and insert the 1D BC row
    BC = sp.eye(base.N).tolil() * 0
    if self.useGPU:
        # the row comes from the GPU base; copy it to host for the scipy LIL matrix
        BC[line, :] = base.get_BC(kind=kind, **kwargs).get()
    else:
        BC[line, :] = base.get_BC(kind=kind, **kwargs)

    # expand the 1D BC row to the full dimensionality via Kronecker products
    ndim = len(self.axes)
    if ndim == 1:
        mat = self.sparse_lib.csc_matrix(BC)
    elif ndim == 2:
        axis2 = (axis + 1) % ndim

        if scalar:
            # only couple to the first mode in the other direction
            _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[axis2].N - 1)))
        else:
            _Id = self.axes[axis2].get_Id()

        Id = self.get_local_slice_of_1D_matrix(self.axes[axis2].get_Id() @ _Id, axis=axis2)

        mats = [
            None,
        ] * ndim
        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)
        mats[axis2] = Id
        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(*mats))
    elif ndim == 3:
        mats = [
            None,
        ] * ndim

        for ax in range(ndim):
            if ax == axis:
                continue

            if scalar:
                # only couple to the first mode in this other direction
                _Id = self.sparse_lib.diags(self.xp.append([1], self.xp.zeros(self.axes[ax].N - 1)))
            else:
                _Id = self.axes[ax].get_Id()

            mats[ax] = self.get_local_slice_of_1D_matrix(self.axes[ax].get_Id() @ _Id, axis=ax)

        mats[axis] = self.get_local_slice_of_1D_matrix(BC, axis=axis)

        mat = self.sparse_lib.csc_matrix(self.sparse_lib.kron(mats[0], self.sparse_lib.kron(*mats[1:])))
    else:
        raise NotImplementedError(
            f'Matrix expansion for boundary conditions not implemented for {ndim} dimensions!'
        )
    mat = self.eliminate_zeros(mat)
    return mat

1345 

def remove_BC(self, component, equation, axis, kind, line=-1, scalar=False, **kwargs):
    """
    Remove a BC from the matrix. This is useful e.g. when you add a non-scalar BC and then need to selectively
    remove single BCs again, as in incompressible Navier-Stokes, for instance.
    Forwards arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    # subtract the same matrix `add_BC` added for these arguments
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    _BC = self.eliminate_zeros(_BC)
    self.BC_mat[self.index(equation)][self.index(component)] -= _BC

    # build the index of the rhs entries belonging to this BC and unset the mask there
    if scalar:
        slices = [self.index(equation)] + [
            0,
        ] * self.ndim
        slices[axis + 1] = line
    else:
        slices = (
            [self.index(equation)]
            + [slice(0, self.init[0][i + 1]) for i in range(axis)]
            + [line]
            + [slice(0, self.init[0][i + 1]) for i in range(axis + 1, len(self.axes))]
        )
    N = self.axes[axis].N
    # only touch the mask if the BC line is actually stored on this rank
    if (N + line) % N in self.xp.arange(N)[self.local_slice()[axis]]:
        self.BC_rhs_mask[(*slices,)] = False

1380 

def add_BC(self, component, equation, axis, kind, v, line=-1, scalar=False, **kwargs):
    """
    Add a BC to the matrix. Note that you need to convert the list of lists of BCs that this method generates to a
    single sparse matrix by calling `setup_BCs` after adding/removing all BCs.
    Forward arguments for the boundary conditions using `kwargs`. Refer to documentation of 1D bases for details.

    Args:
        component (str): Name of the component the BC should act on
        equation (str): Name of the equation for the component you want to put the BC in
        axis (int): Axis you want to add the BC to
        kind (str): kind of BC, e.g. Dirichlet
        v: Value of the BC
        line (int): Line you want the BC to go in
        scalar (bool): Put the BC in all space positions in the other direction
    """
    _BC = self.get_BC(axis=axis, kind=kind, line=line, scalar=scalar, **kwargs)
    self.BC_mat[self.index(equation)][self.index(component)] += _BC
    # remember the full description so the BC can be re-applied to right hand sides later
    self.full_BCs += [
        {
            'component': component,
            'equation': equation,
            'axis': axis,
            'kind': kind,
            'v': v,
            'line': line,
            'scalar': scalar,
            **kwargs,
        }
    ]

    if scalar:
        slices = [self.index(equation)] + [
            0,
        ] * self.ndim
        slices[axis + 1] = line
        # a scalar BC lives in a single global entry; only rank 0 owns it
        if self.comm:
            if self.comm.rank == 0:
                self.BC_rhs_mask[(*slices,)] = True
        else:
            self.BC_rhs_mask[(*slices,)] = True
    else:
        slices = [self.index(equation), *self.global_slice(True)]
        N = self.axes[axis].N
        # only mark the mask if the BC line is stored on this rank; translate to local index
        if (N + line) % N in self.get_indices(True)[axis]:
            slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
            self.BC_rhs_mask[(*slices,)] = True

1427 

def setup_BCs(self):
    """
    Convert the list of lists of BCs to the boundary condition operator.
    Also, boundary bordering requires to zero out all other entries in the matrix in rows containing a boundary
    condition. This method sets up a suitable sparse matrix to do this.

    Sets `BCs`, `BC_zero_index`, `BC_line_zero_matrix` and `rhs_BCs_hat`, and deletes
    `BC_rhs_mask`, which is no longer needed afterwards.
    """
    sp = self.sparse_lib
    self.BCs = self.convert_operator_matrix_to_operator(self.BC_mat)
    # flat indices of all rows that are replaced by boundary conditions
    self.BC_zero_index = self.xp.arange(np.prod(self.init[0]))[self.BC_rhs_mask.flatten()]

    # diagonal matrix with zeros exactly in the BC rows, used to blank them out
    diags = self.xp.ones(self.BCs.shape[0])
    diags[self.BC_zero_index] = 0
    self.BC_line_zero_matrix = sp.diags(diags).tocsc()

    # prepare BCs in spectral space to easily add to the RHS
    rhs_BCs = self.put_BCs_in_rhs(self.u_init)
    self.rhs_BCs_hat = self.transform(rhs_BCs).view(self.xp.ndarray)
    del self.BC_rhs_mask

1446 

def check_BCs(self, u):
    """
    Check that the solution satisfies the boundary conditions.

    Only non-scalar BCs are checked, and only in up to two dimensions.

    Args:
        u: The solution you want to check

    Raises:
        AssertionError: If a boundary condition is violated or the problem has three or more dimensions
    """
    assert self.ndim < 3
    for axis in range(self.ndim):
        BCs = [me for me in self.full_BCs if me["axis"] == axis and not me["scalar"]]

        if len(BCs) > 0:
            u_hat = self.transform(u, axes=(axis - self.ndim,))
            for BC in BCs:
                # keep only the kwargs that were forwarded to the 1D base when the BC was added
                kwargs = {
                    key: value
                    for key, value in BC.items()
                    if key not in ['component', 'equation', 'axis', 'v', 'line', 'scalar']
                }

                # apply the BC row from the left or right depending on which axis it acts on
                if axis == 0:
                    get = self.axes[axis].get_BC(**kwargs) @ u_hat[self.index(BC['component'])]
                elif axis == 1:
                    get = u_hat[self.index(BC['component'])] @ self.axes[axis].get_BC(**kwargs)
                want = BC['v']
                assert self.xp.allclose(
                    get, want
                ), f'Unexpected BC in {BC["component"]} in equation {BC["equation"]}, line {BC["line"]}! Got {get}, wanted {want}'

1475 

def put_BCs_in_matrix(self, A):
    """
    Replace the rows of `A` that belong to boundary conditions by the BC rows.

    Args:
        A: System matrix

    Returns:
        Matrix with boundary condition rows inserted
    """
    zeroed_rows = self.BC_line_zero_matrix @ A
    return zeroed_rows + self.BCs

1481 

def put_BCs_in_rhs_hat(self, rhs_hat):
    """
    Put the BCs in the right hand side in spectral space for solving.
    This function needs no transforms and caches a mask for faster subsequent use.

    Args:
        rhs_hat: Right hand side in spectral space

    Returns:
        rhs in spectral space with BCs
    """
    if not hasattr(self, '_rhs_hat_zero_mask'):
        """
        Generate a mask where we need to set values in the rhs in spectral space to zero, such that can replace them
        by the boundary conditions. The mask is then cached.
        """
        self._rhs_hat_zero_mask = self.newDistArray(forward_output=True).astype(bool).view(self.xp.ndarray)

        for axis in range(self.ndim):
            for bc in self.full_BCs:
                if axis == bc['axis']:
                    slices = [self.index(bc['equation']), *self.global_slice(True)]
                    N = self.axes[axis].N
                    line = bc['line']
                    # only mark entries that are stored on this rank; translate to local index
                    if (N + line) % N in self.get_indices(True)[axis]:
                        slices[axis + 1] = (N + line) % N - self.local_slice()[axis].start
                        self._rhs_hat_zero_mask[(*slices,)] = True

    # zero out the BC rows, then add the precomputed spectral-space BC values
    rhs_hat[self._rhs_hat_zero_mask] = 0
    return rhs_hat + self.rhs_BCs_hat

1512 

def put_BCs_in_rhs(self, rhs):
    """
    Put the BCs in the right hand side for solving.
    This function will transform along each axis individually and add all BCs in that axis.
    Consider `put_BCs_in_rhs_hat` to add BCs with no extra transforms needed.

    Args:
        rhs: Right hand side in physical space

    Returns:
        rhs in physical space with BCs
    """
    assert rhs.ndim > 1, 'rhs must not be flattened here!'

    ndim = self.ndim

    for axis in range(ndim):
        # transform only along the current axis (negative index skips the component axis)
        _rhs_hat = self.transform(rhs, axes=(axis - ndim,))

        for bc in self.full_BCs:

            if axis == bc['axis']:
                _slice = [self.index(bc['equation']), *self.global_slice(True)]

                N = self.axes[axis].N
                line = bc['line']
                # only set the value if the BC line is stored on this rank; translate to local index
                if (N + line) % N in self.get_indices(True)[axis]:
                    _slice[axis + 1] = (N + line) % N - self.local_slice()[axis].start
                    _rhs_hat[(*_slice,)] = bc['v']

        # back to physical space before treating the next axis
        rhs = self.itransform(_rhs_hat, axes=(axis - ndim,))

    return rhs

1546 

def add_equation_lhs(self, A, equation, relations):
    """
    Add the left hand part (that you want to solve implicitly) of an equation to a list of lists
    of sparse matrices that you will convert to a single operator later via
    `convert_operator_matrix_to_operator`.

    Example:
        Setup linear operator `L` for 1D heat equation using Chebychev method in first order form
        and T-to-U preconditioning:

        .. code-block:: python
            helper = SpectralHelper()

            helper.add_axis(base='chebychev', N=8)
            helper.add_component(['u', 'ux'])
            helper.setup_fft()

            I = helper.get_Id()
            Dx = helper.get_differentiation_matrix(axes=(0,))
            T2U = helper.get_basis_change_matrix('T2U')

            L_lhs = {
                'ux': {'u': -T2U @ Dx, 'ux': T2U @ I},
                'u': {'ux': -(T2U @ Dx)},
            }

            operator = helper.get_empty_operator_matrix()
            for line, equation in L_lhs.items():
                helper.add_equation_lhs(operator, line, equation)

            L = helper.convert_operator_matrix_to_operator(operator)

    Args:
        A (list of lists of sparse matrices): The operator to be filled
        equation (str): The equation of the component you want this in
        relations (dict): Maps component names to the sparse operators acting on them
    """
    row = self.index(equation)
    for component, operator in relations.items():
        A[row][self.index(component)] = operator

1585 

def eliminate_zeros(self, A):
    """
    Eliminate explicitly stored zeros from a sparse matrix to reduce its memory footprint.

    Note: At the time of writing, there are memory problems in the cupy implementation of
    `eliminate_zeros`. Therefore, on GPU the matrix is copied to the host, pruned there and
    copied back.

    Args:
        A: sparse matrix to be pruned

    Returns:
        CSC sparse matrix
    """
    on_gpu = self.useGPU
    if on_gpu:
        A = A.get()  # move to host before pruning
    A = A.tocsc()
    A.eliminate_zeros()
    if on_gpu:
        A = self.sparse_lib.csc_matrix(A)  # move back to device
    return A

1605 

def convert_operator_matrix_to_operator(self, M):
    """
    Promote the list of lists of sparse matrices to a single sparse matrix that can be used as
    linear operator. See documentation of `SpectralHelper.add_equation_lhs` for an example.

    Args:
        M (list of lists of sparse matrices): The operator to be assembled

    Returns:
        sparse linear operator
    """
    if len(self.components) == 1:
        operator = M[0][0]
    else:
        operator = self.sparse_lib.bmat(M, format='csc')

    return self.eliminate_zeros(operator)

1624 

def get_wavenumbers(self):
    """
    Get the local portion of the grid in spectral space as a meshgrid.
    """
    slices = self.local_slice(True)
    grids = [axis.get_wavenumbers()[slices[i]] for i, axis in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')

1631 

def get_grid(self, forward_output=False):
    """
    Get the local portion of the grid in physical space as a meshgrid.
    """
    slices = self.local_slice(forward_output)
    grids = [axis.get_1dgrid()[slices[i]] for i, axis in enumerate(self.axes)]
    return self.xp.meshgrid(*grids, indexing='ij')

1638 

def get_indices(self, forward_output=True):
    """Local index ranges along every axis."""
    slices = self.local_slice(forward_output)
    return [self.xp.arange(axis.N)[slices[i]] for i, axis in enumerate(self.axes)]

1641 

@cache
def get_pfft(self, axes=None, padding=None, grid=None):
    """
    Construct (and cache, via the `cache` decorator) a parallel FFT object from mpi4py-fft.

    Returns None for 1D problems or when no MPI communicator is set, in which case the serial
    code paths are used instead.

    Args:
        axes (tuple): Axes you want to transform over; defaults to all axes
        padding (list): Padding factors per axis for dealiasing
        grid: Forwarded to `PFFT` — see mpi4py-fft documentation

    Returns:
        mpi4py_fft.PFFT or None
    """
    if self.ndim == 1 or self.comm is None:
        return None
    from mpi4py_fft import PFFT, newDistArray

    axes = tuple(i for i in range(self.ndim)) if axes is None else axes
    padding = list(padding if padding else [1.0 for _ in range(self.ndim)])

    def no_transform(u, *args, **kwargs):
        # identity stand-in for axes that should not be transformed
        return u

    # default every axis to the identity, then install the real 1D transforms on the requested axes
    transforms = {(i,): (no_transform, no_transform) for i in range(self.ndim)}
    for i in axes:
        transforms[((i + self.ndim) % self.ndim,)] = (self.axes[i].transform, self.axes[i].itransform)

    # "transform" all axes to ensure consistent shapes.
    # Transform non-distributable axes last to ensure they are aligned
    _axes = tuple(sorted((axis + self.ndim) % self.ndim for axis in axes))
    _axes = [axis for axis in _axes if not self.axes[axis].distributable] + sorted(
        [axis for axis in _axes if self.axes[axis].distributable]
        + [axis for axis in range(self.ndim) if axis not in _axes]
    )

    pfft = PFFT(
        comm=self.comm,
        shape=self.global_shape[1:],
        axes=_axes,  # TODO: control the order of the transforms better
        dtype='D',
        collapse=False,
        backend=self.fft_backend,
        comm_backend=self.fft_comm_backend,
        padding=padding,
        transforms=transforms,
        grid=grid,
    )

    # do a transform to do the planning
    _u = newDistArray(pfft, forward_output=False)
    pfft.backward(pfft.forward(_u))
    return pfft

1683 

def get_fft(self, axes=None, direction='object', padding=None, shape=None):
    """
    When using MPI, we use `PFFT` objects generated by mpi4py-fft

    Args:
        axes (tuple): Axes you want to transform over
        direction (str): use "forward" or "backward" to get functions for performing the transforms or "object" to get the PFFT object
        padding (tuple): Padding for dealiasing
        shape (tuple): Shape of the transform

    Returns:
        transform
    """
    axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
    shape = self.global_shape[1:] if shape is None else shape
    padding = (
        [
            1,
        ]
        * self.ndim
        if padding is None
        else padding
    )
    # objects are cached per (axes, direction, padding, shape) combination
    key = (axes, direction, tuple(padding), tuple(shape))

    if key not in self.fft_cache.keys():
        if self.comm is None:
            # serial path: plain numpy FFTs, which do not support zero padding
            assert np.allclose(padding, 1), 'Zero padding is not implemented for non-MPI transforms'

            if direction == 'forward':
                self.fft_cache[key] = self.xp.fft.fftn
            elif direction == 'backward':
                self.fft_cache[key] = self.xp.fft.ifftn
            elif direction == 'object':
                self.fft_cache[key] = None
        else:
            if direction == 'object':
                from mpi4py_fft import PFFT

                _fft = PFFT(
                    comm=self.comm,
                    shape=shape,
                    axes=sorted(axes),
                    dtype='D',
                    collapse=False,
                    backend=self.fft_backend,
                    comm_backend=self.fft_comm_backend,
                    padding=padding,
                )
            else:
                # reuse the cached PFFT object and hand out its bound methods
                _fft = self.get_fft(axes=axes, direction='object', padding=padding, shape=shape)

            if direction == 'forward':
                self.fft_cache[key] = _fft.forward
            elif direction == 'backward':
                self.fft_cache[key] = _fft.backward
            elif direction == 'object':
                self.fft_cache[key] = _fft

    return self.fft_cache[key]

1744 

def local_slice(self, forward_output=True):
    """
    Slices describing the part of the global data that is stored on this rank.
    Without a distributed FFT object, everything is local.
    """
    if self.fft_obj:
        return self.get_pfft().local_slice(forward_output=forward_output)
    return [slice(0, axis.N) for axis in self.axes]

1750 

def global_slice(self, forward_output=True):
    """
    Slices describing the global shape of the data.
    Without a distributed FFT object, this coincides with the local slices.
    """
    if not self.fft_obj:
        return self.local_slice(forward_output=forward_output)
    return [slice(0, n) for n in self.fft_obj.global_shape(forward_output=forward_output)]

1756 

def setup_fft(self, real_spectral_coefficients=False):
    """
    This function must be called after all axes have been setup in order to prepare the local shapes of the data.
    This must also be called before setting up any BCs.

    Sets `global_shape`, `fft_obj`, the `init`/`init_physical`/`init_forward` tuples used by
    the problem classes, and initializes `BC_mat` and `BC_rhs_mask`.

    Args:
        real_spectral_coefficients (bool): Allow only real coefficients in spectral space
    """
    # default to a single component named 'u' if the user registered none
    if len(self.components) == 0:
        self.add_component('u')

    self.global_shape = (len(self.components),) + tuple(me.N for me in self.axes)

    axes = tuple(i for i in range(len(self.axes)))
    self.fft_obj = self.get_pfft(axes=axes)

    # the `init` tuples are (local shape, communicator, dtype), slicing the global
    # shape down to the part stored on this rank
    self.init = (
        np.empty(shape=self.global_shape)[
            (
                ...,
                *self.local_slice(False),
            )
        ].shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_physical = (
        np.empty(shape=self.global_shape)[
            (
                ...,
                *self.local_slice(False),
            )
        ].shape,
        self.comm,
        np.dtype('float'),
    )
    self.init_forward = (
        np.empty(shape=self.global_shape)[
            (
                ...,
                *self.local_slice(True),
            )
        ].shape,
        self.comm,
        np.dtype('float') if real_spectral_coefficients else np.dtype('complex128'),
    )

    self.BC_mat = self.get_empty_operator_matrix()
    self.BC_rhs_mask = self.newDistArray().astype(bool)

1806 

def newDistArray(self, pfft=None, forward_output=True, val=0, rank=1, view=False):
    """
    Get an empty distributed array. This is almost a copy of the function of the same name from mpi4py-fft, but
    takes care of all the solution components in the tensor.

    Args:
        pfft: PFFT object to take shapes from; defaults to `self.get_pfft()`
        forward_output (bool): Shape/dtype of spectral space if True, of physical space otherwise
        val: Initial value of the array
        rank (int): Number of leading component axes to prepend
        view (bool): Return a plain ndarray view instead of the DistArray

    Returns:
        Distributed array (or local array in the serial case)
    """
    # serial case: just a local numpy-like array with the stored shape and dtype
    if self.comm is None:
        return self.xp.zeros(self.init[0], dtype=self.init[2])
    from mpi4py_fft.distarray import DistArray

    pfft = pfft if pfft else self.get_pfft()
    if pfft is None:
        if forward_output:
            return self.u_init_forward
        else:
            return self.u_init

    global_shape = pfft.global_shape(forward_output)
    p0 = pfft.pencil[forward_output]
    if forward_output is True:
        dtype = pfft.forward.output_array.dtype
    else:
        dtype = pfft.forward.input_array.dtype
    # prepend the component axis (or axes, for rank > 1)
    global_shape = (self.ncomponents,) * rank + global_shape

    # pick the GPU-aware DistArray variant if the FFT backend runs on cupy
    if pfft.xfftn[0].backend in ["cupy", "cupyx-scipy"]:
        from mpi4py_fft.distarrayCuPy import DistArrayCuPy as darraycls
    else:
        darraycls = DistArray

    z = darraycls(global_shape, subcomm=p0.subcomm, val=val, dtype=dtype, alignment=p0.axis, rank=rank)
    return z.v if view else z

1838 

def infer_alignment(self, u, forward_output, padding=None, **kwargs):
    """
    Infer along which axes the local array `u` is aligned, i.e. which axes are not distributed.

    When padding is involved, several padding configurations are tried until one produces a
    distributed array whose global shape matches `u`.

    Args:
        u: Local array whose alignment should be determined
        forward_output (bool): Whether `u` lives in spectral space
        padding (tuple): Padding factors, if any

    Returns:
        list of int: Axes along which `u` is aligned

    Raises:
        AssertionError: If no alignment could be found
        NotImplementedError: For dimensions other than 2 or 3 when padding is given
    """
    # serial case: everything is local, so alignment is trivial
    if self.comm is None:
        return [0]

    def _alignment(pfft):
        # axes where the local shape equals the global shape are aligned (+1 skips the component axis)
        _arr = self.newDistArray(pfft, forward_output=forward_output)
        _aligned_axes = [i for i in range(self.ndim) if _arr.global_shape[i + 1] == u.shape[i + 1]]
        return _aligned_axes

    if padding is None:
        pfft = self.get_pfft(**kwargs)
        aligned_axes = _alignment(pfft)
    else:
        # try padding only some axes, since `u` may be padded in a subset of directions
        if self.ndim == 2:
            padding_options = [(1.0, padding[1]), (padding[0], 1.0), padding, (1.0, 1.0)]
        elif self.ndim == 3:
            padding_options = [
                (1.0, 1.0, padding[2]),
                (1.0, padding[1], 1.0),
                (padding[0], 1.0, 1.0),
                (1.0, padding[1], padding[2]),
                (padding[0], 1.0, padding[2]),
                (padding[0], padding[1], 1.0),
                padding,
                (1.0, 1.0, 1.0),
            ]
        else:
            raise NotImplementedError(f'Don\'t know how to infer alignment in {self.ndim}D!')
        for _padding in padding_options:
            pfft = self.get_pfft(padding=_padding, **kwargs)
            aligned_axes = _alignment(pfft)
            if len(aligned_axes) > 0:
                self.logger.debug(
                    f'Found alignment of array with size {u.shape}: {aligned_axes} using padding {_padding}'
                )
                break

    assert len(aligned_axes) > 0, f'Found no aligned axes for array of size {u.shape}!'
    return aligned_axes

1878 

def redistribute(self, u, axis, forward_output, **kwargs):
    """
    Redistribute the local array `u` such that it is aligned along `axis`.

    Args:
        u: Local array to redistribute
        axis (int): Axis the result should be aligned in
        forward_output (bool): Whether `u` lives in spectral space

    Returns:
        Redistributed array

    Raises:
        Exception: If no compatible alignment of `u` could be found
    """
    # serial case: nothing to redistribute
    if self.comm is None:
        return u

    pfft = self.get_pfft(**kwargs)
    _arr = self.newDistArray(pfft, forward_output=forward_output)

    # NOTE(review): the `and False` deliberately disables this DistArray fast path;
    # the generic alignment search below is always used instead
    if 'Dist' in type(u).__name__ and False:
        try:
            u.redistribute(out=_arr)
            return _arr
        except AssertionError:
            pass

    # try every alignment of `u` until the redistributed shapes match
    u_alignment = self.infer_alignment(u, forward_output=False, **kwargs)
    for alignment in u_alignment:
        _arr = _arr.redistribute(alignment)
        if _arr.shape == u.shape:
            _arr[...] = u
            return _arr.redistribute(axis)

    raise Exception(
        f'Don\'t know how to align array of local shape {u.shape} and global shape {self.global_shape}, aligned in axes {u_alignment}, to axis {axis}'
    )

1903 

1904 def transform(self, u, *args, axes=None, padding=None, **kwargs): 

1905 pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs) 

1906 

1907 if pfft is None: 

1908 axes = axes if axes else tuple(i for i in range(self.ndim)) 

1909 u_hat = u.copy() 

1910 for i in axes: 

1911 _axis = 1 + i if i >= 0 else i 

1912 u_hat = self.axes[i].transform(u_hat, axes=(_axis,)) 

1913 return u_hat 

1914 

1915 _in = self.newDistArray(pfft, forward_output=False, rank=1) 

1916 _out = self.newDistArray(pfft, forward_output=True, rank=1) 

1917 

1918 if _in.shape == u.shape: 

1919 _in[...] = u 

1920 else: 

1921 _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=False, padding=padding, **kwargs) 

1922 

1923 for i in range(self.ncomponents): 

1924 pfft.forward(_in[i], _out[i], normalize=False) 

1925 

1926 if padding is not None: 

1927 _out /= np.prod(padding) 

1928 return _out 

1929 

1930 def itransform(self, u, *args, axes=None, padding=None, **kwargs): 

1931 if padding is not None: 

1932 assert all( 

1933 (self.axes[i].N * padding[i]) % 1 == 0 for i in range(self.ndim) 

1934 ), 'Cannot do this padding with this resolution. Resulting resolution must be integer' 

1935 

1936 pfft = self.get_pfft(*args, axes=axes, padding=padding, **kwargs) 

1937 if pfft is None: 

1938 axes = axes if axes else tuple(i for i in range(self.ndim)) 

1939 u_hat = u.copy() 

1940 for i in axes: 

1941 _axis = 1 + i if i >= 0 else i 

1942 u_hat = self.axes[i].itransform(u_hat, axes=(_axis,)) 

1943 return u_hat 

1944 

1945 _in = self.newDistArray(pfft, forward_output=True, rank=1) 

1946 _out = self.newDistArray(pfft, forward_output=False, rank=1) 

1947 

1948 if _in.shape == u.shape: 

1949 _in[...] = u 

1950 else: 

1951 _in[...] = self.redistribute(u, axis=_in.alignment, forward_output=True, padding=padding, **kwargs) 

1952 

1953 for i in range(self.ncomponents): 

1954 pfft.backward(_in[i], _out[i], normalize=True) 

1955 

1956 if padding is not None: 

1957 _out *= np.prod(padding) 

1958 return _out 

1959 

1960 def get_local_slice_of_1D_matrix(self, M, axis): 

1961 """ 

1962 Get the local version of a 1D matrix. When using distributed FFTs, each rank will carry only a subset of modes, 

1963 which you can sort out via the `SpectralHelper.local_slice()` attribute. When constructing a 1D matrix, you can 

1964 use this method to get the part corresponding to the modes carried by this rank. 

1965 

1966 Args: 

1967 M (sparse matrix): Global 1D matrix you want to get the local version of 

1968 axis (int): Direction in which you want the local version. You will get the global matrix in other directions. 

1969 

1970 Returns: 

1971 sparse local matrix 

1972 """ 

1973 return M.tocsc()[self.local_slice(True)[axis], self.local_slice(True)[axis]] 

1974 

1975 def expand_matrix_ND(self, matrix, aligned): 

1976 sp = self.sparse_lib 

1977 axes = np.delete(np.arange(self.ndim), aligned) 

1978 ndim = len(axes) + 1 

1979 

1980 if ndim == 1: 

1981 mat = matrix 

1982 elif ndim == 2: 

1983 axis = axes[0] 

1984 I1D = sp.eye(self.axes[axis].N) 

1985 

1986 mats = [None] * ndim 

1987 mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned) 

1988 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) 

1989 

1990 mat = sp.kron(*mats) 

1991 elif ndim == 3: 

1992 

1993 mats = [None] * ndim 

1994 mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned) 

1995 for axis in axes: 

1996 I1D = sp.eye(self.axes[axis].N) 

1997 mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) 

1998 

1999 mat = sp.kron(mats[0], sp.kron(*mats[1:])) 

2000 

2001 else: 

2002 raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!') 

2003 

2004 mat = self.eliminate_zeros(mat) 

2005 return mat 

2006 

2007 def get_filter_matrix(self, axis, **kwargs): 

2008 """ 

2009 Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are 

2010 admissible. 

2011 

2012 Returns: 

2013 sparse bandpass matrix 

2014 """ 

2015 if self.ndim == 1: 

2016 return self.axes[0].get_filter_matrix(**kwargs) 

2017 

2018 mats = [base.get_Id() for base in self.axes] 

2019 mats[axis] = self.axes[axis].get_filter_matrix(**kwargs) 

2020 return self.sparse_lib.kron(*mats) 

2021 

2022 def get_differentiation_matrix(self, axes, **kwargs): 

2023 """ 

2024 Get differentiation matrix along specified axis. `kwargs` are forwarded to the 1D base implementation. 

2025 

2026 Args: 

2027 axes (tuple): Axes along which to differentiate. 

2028 

2029 Returns: 

2030 sparse differentiation matrix 

2031 """ 

2032 D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0]) 

2033 for axis in axes[1:]: 

2034 _D = self.axes[axis].get_differentiation_matrix(**kwargs) 

2035 D = D @ self.expand_matrix_ND(_D, axis) 

2036 

2037 self.logger.debug(f'Set up differentiation matrix along axes {axes} with kwargs {kwargs}') 

2038 return D 

2039 

2040 def get_integration_matrix(self, axes): 

2041 """ 

2042 Get integration matrix to integrate along specified axis. 

2043 

2044 Args: 

2045 axes (tuple): Axes along which to integrate over. 

2046 

2047 Returns: 

2048 sparse integration matrix 

2049 """ 

2050 S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0]) 

2051 for axis in axes[1:]: 

2052 _S = self.axes[axis].get_integration_matrix() 

2053 S = S @ self.expand_matrix_ND(_S, axis) 

2054 

2055 return S 

2056 

2057 def get_Id(self): 

2058 """ 

2059 Get identity matrix 

2060 

2061 Returns: 

2062 sparse identity matrix 

2063 """ 

2064 I = self.expand_matrix_ND(self.axes[0].get_Id(), 0) 

2065 for axis in range(1, self.ndim): 

2066 _I = self.axes[axis].get_Id() 

2067 I = I @ self.expand_matrix_ND(_I, axis) 

2068 return I 

2069 

2070 def get_Dirichlet_recombination_matrix(self, axis=-1): 

2071 """ 

2072 Get Dirichlet recombination matrix along axis. Not that it only makes sense in directions discretized with variations of Chebychev bases. 

2073 

2074 Args: 

2075 axis (int): Axis you discretized with Chebychev 

2076 

2077 Returns: 

2078 sparse matrix 

2079 """ 

2080 C1D = self.axes[axis].get_Dirichlet_recombination_matrix() 

2081 return self.expand_matrix_ND(C1D, axis) 

2082 

2083 def get_basis_change_matrix(self, axes=None, **kwargs): 

2084 """ 

2085 Some spectral bases do a change between bases while differentiating. This method returns matrices that changes the basis to whatever you want. 

2086 Refer to the methods of the same name of the 1D bases to learn what parameters you need to pass here as `kwargs`. 

2087 

2088 Args: 

2089 axes (tuple): Axes along which to change basis. 

2090 

2091 Returns: 

2092 sparse basis change matrix 

2093 """ 

2094 axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes 

2095 

2096 C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0]) 

2097 for axis in axes[1:]: 

2098 _C = self.axes[axis].get_basis_change_matrix(**kwargs) 

2099 C = C @ self.expand_matrix_ND(_C, axis) 

2100 

2101 self.logger.debug(f'Set up basis change matrix along axes {axes} with kwargs {kwargs}') 

2102 return C