Coverage for pySDC/implementations/problem_classes/generic_spectral.py: 59%

198 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 11:36 +0000

1from pySDC.core.problem import Problem, WorkCounter 

2from pySDC.helpers.spectral_helper import SpectralHelper 

3import numpy as np 

4from pySDC.core.errors import ParameterError 

5from pySDC.helpers.fieldsIO import Rectilinear 

6 

7 

8class GenericSpectralLinear(Problem): 

9 """ 

10 Generic class to solve problems of the form M u_t + L u = y, with mass matrix M, linear operator L and some right 

11 hand side y using spectral methods. 

12 L may contain algebraic conditions, as long as (M + dt L) is invertible. 

13 

14 Note that the `__getattr__` method is overloaded to pass requests on to the spectral helper if they are not 

15 attributes of this class itself. For instance, you can add a BC by calling `self.spectral.add_BC` or equivalently 

16 `self.add_BC`. 

17 

18 You can port problems derived from this more or less seamlessly to GPU by using the numerical libraries that are 

19 class attributes of the spectral helper. This class will automatically switch the datatype using the `setup_GPU` class method. 

20 

21 Attributes: 

22 spectral (pySDC.helpers.spectral_helper.SpectralHelper): Spectral helper 

23 work_counters (dict): Dictionary for counting work 

24 cached_factorizations (dict): Dictionary of cached matrix factorizations for solving 

25 L (sparse matrix): Linear operator 

26 M (sparse matrix): Mass matrix 

27 diff_mask (list): Mask for separating differential and algebraic terms 

28 Pl (sparse matrix): Left preconditioner 

29 Pr (sparse matrix): Right preconditioner 

30 """ 

31 

32 def setup_GPU(self): 

33 """switch to GPU modules""" 

34 import cupy as cp 

35 from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh, imex_cupy_mesh 

36 from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh 

37 

38 self.dtype_u = cupy_mesh 

39 

40 GPU_versions = { 

41 mesh: cupy_mesh, 

42 imex_mesh: imex_cupy_mesh, 

43 } 

44 

45 self.dtype_f = GPU_versions[self.dtype_f] 

46 

47 if self.comm is not None: 

48 from pySDC.helpers.NCCL_communicator import NCCLComm 

49 

50 if not isinstance(self.comm, NCCLComm): 

51 self.__dict__['comm'] = NCCLComm(self.comm) 

52 

53 def __init__( 

54 self, 

55 bases, 

56 components, 

57 comm=None, 

58 Dirichlet_recombination=True, 

59 left_preconditioner=True, 

60 solver_type='cached_direct', 

61 solver_args=None, 

62 preconditioner_args=None, 

63 useGPU=False, 

64 max_cached_factorizations=12, 

65 spectral_space=True, 

66 real_spectral_coefficients=False, 

67 debug=False, 

68 ): 

69 """ 

70 Base class for problems discretized with spectral methods. 

71 

72 Args: 

73 bases (list of dictionaries): 1D Bases 

74 components (list of strings): Components of the equations 

75 comm (mpi4py.Intracomm or None): MPI communicator 

76 Dirichlet_recombination (bool): Use Dirichlet recombination in the last axis as right preconditioner 

77 left_preconditioner (bool): Reverse the Kronecker product if yes 

78 solver_type (str): Solver for linear systems 

79 solver_args (dict): Arguments for linear solver 

80 useGPU (bool): Run on GPU or CPU 

81 max_cached_factorizations (int): Number of matrix decompositions to cache before starting eviction 

82 spectral_space (bool): If yes, the solution will not be transformed back after solving and evaluating the RHS, and is expected as input in spectral space to these functions 

83 real_spectral_coefficients (bool): If yes, allow only real values in spectral space, otherwise, allow complex. 

84 debug (bool): Make additional tests at extra computational cost 

85 """ 

86 solver_args = {} if solver_args is None else solver_args 

87 

88 preconditioner_args = {} if preconditioner_args is None else preconditioner_args 

89 preconditioner_args['drop_tol'] = preconditioner_args.get('drop_tol', 1e-3) 

90 preconditioner_args['fill_factor'] = preconditioner_args.get('fill_factor', 100) 

91 

92 self._makeAttributeAndRegister( 

93 'max_cached_factorizations', 

94 'useGPU', 

95 'solver_type', 

96 'solver_args', 

97 'preconditioner_args', 

98 'left_preconditioner', 

99 'Dirichlet_recombination', 

100 'comm', 

101 'spectral_space', 

102 'real_spectral_coefficients', 

103 'debug', 

104 localVars=locals(), 

105 ) 

106 self.spectral = SpectralHelper(comm=comm, useGPU=useGPU, debug=debug) 

107 

108 if useGPU: 

109 self.setup_GPU() 

110 if self.solver_args is not None: 

111 if 'rtol' in self.solver_args.keys(): 

112 self.solver_args['tol'] = self.solver_args.pop('rtol') 

113 

114 for base in bases: 

115 self.spectral.add_axis(**base) 

116 self.spectral.add_component(components) 

117 

118 self.spectral.setup_fft(real_spectral_coefficients) 

119 

120 super().__init__(init=self.spectral.init_forward if spectral_space else self.spectral.init) 

121 

122 self.work_counters[solver_type] = WorkCounter() 

123 self.work_counters['factorizations'] = WorkCounter() 

124 

125 self.setup_preconditioner(Dirichlet_recombination, left_preconditioner) 

126 

127 self.cached_factorizations = {} 

128 

129 def __getattr__(self, name): 

130 """ 

131 Pass requests on to the helper if they are not directly attributes of this class for convenience. 

132 

133 Args: 

134 name (str): Name of the attribute you want 

135 

136 Returns: 

137 request 

138 """ 

139 return getattr(self.spectral, name) 

140 

141 def _setup_operator(self, LHS, diag=False): 

142 """ 

143 Setup a sparse linear operator by adding relationships. See documentation for ``GenericSpectralLinear.setup_L`` to learn more. 

144 

145 Args: 

146 LHS (dict): Equations to be added to the operator 

147 diag (bool): Whether operator is block-diagonal 

148 

149 Returns: 

150 sparse linear operator 

151 """ 

152 operator = self.spectral.get_empty_operator_matrix(diag=diag) 

153 for line, equation in LHS.items(): 

154 self.spectral.add_equation_lhs(operator, line, equation, diag=diag) 

155 return self.spectral.convert_operator_matrix_to_operator(operator, diag=diag) 

156 

157 def setup_L(self, LHS): 

158 """ 

159 Setup the left hand side of the linear operator L and store it in ``self.L``. 

160 

161 The argument is meant to be a dictionary with the line you want to write the equation in as the key and the relationship between components as another dictionary. For instance, you can add an algebraic condition capturing a first derivative relationship between u and ux as follows: 

162 

163 ``` 

164 Dx = self.get_differentiation_matrix(axes=(0,)) 

165 I = self.get_Id() 

166 LHS = {'ux': {'u': Dx, 'ux': -I}} 

167 self.setup_L(LHS) 

168 ``` 

169 

170 If you put zero as right hand side for the solver in the line for ux, ux will contain the x-derivative of u afterwards. 

171 

172 Args: 

173 LHS (dict): Dictionary containing the equations. 

174 """ 

175 self.L = self._setup_operator(LHS) 

176 

177 def setup_M(self, LHS, diag=True): 

178 ''' 

179 Setup mass matrix, see documentation of ``GenericSpectralLinear.setup_L``. 

180 ''' 

181 diff_index = list(LHS.keys()) 

182 self.diff_mask = [me in diff_index for me in self.components] 

183 self.M = self._setup_operator(LHS, diag=diag) 

184 

185 def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner=True): 

186 """ 

187 Get left and right preconditioners. 

188 

189 Args: 

190 Dirichlet_recombination (bool): Basis conversion for right preconditioner. Useful for Chebychev and Ultraspherical methods. 10/10 would recommend. 

191 left_preconditioner (bool): If True, it will interleave the variables and reverse the Kronecker product 

192 """ 

193 sp = self.spectral.sparse_lib 

194 N = np.prod(self.init[0][1:]) 

195 

196 Id = sp.eye(N) 

197 Pl_lhs = {comp: {comp: Id} for comp in self.components} 

198 self.Pl = self._setup_operator(Pl_lhs, diag=True) 

199 

200 if left_preconditioner: 

201 # reverse Kronecker product 

202 

203 if self.spectral.useGPU: 

204 R = self.Pl.get().tolil() * 0 

205 else: 

206 R = self.Pl.tolil() * 0 

207 

208 for j in range(self.ncomponents): 

209 for i in range(N): 

210 R[i * self.ncomponents + j, j * N + i] = 1.0 

211 

212 self.Pl = self.spectral.sparse_lib.csc_matrix(R) 

213 

214 if Dirichlet_recombination and type(self.axes[-1]).__name__ in ['ChebychevHelper', 'UltrasphericalHelper']: 

215 _Pr = self.spectral.get_Dirichlet_recombination_matrix(axis=-1) 

216 else: 

217 _Pr = Id 

218 

219 Pr_lhs = {comp: {comp: _Pr} for comp in self.components} 

220 self.Pr = self._setup_operator(Pr_lhs, diag=True) @ self.Pl.T 

221 

222 def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs): 

223 """ 

224 Do an implicit Euler step to solve M u_t + Lu = rhs, with M the mass matrix and L the linear operator as setup by 

225 ``GenericSpectralLinear.setup_L`` and ``GenericSpectralLinear.setup_M``. 

226 

227 The implicit Euler step is (M - dt L) u = M rhs. Note that M need not be invertible as long as (M + dt*L) is. 

228 This means solving with dt=0 to mimic explicit methods does not work for all problems, in particular simple DAEs. 

229 

230 Note that by putting M rhs on the right hand side, this function can only solve algebraic conditions equal to 

231 zero. If you want something else, it should be easy to overload this function. 

232 """ 

233 

234 sp = self.spectral.sparse_lib 

235 

236 if self.spectral_space: 

237 rhs_hat = rhs.copy() 

238 if u0 is not None: 

239 u0_hat = u0.copy().flatten() 

240 else: 

241 u0_hat = None 

242 else: 

243 rhs_hat = self.spectral.transform(rhs) 

244 if u0 is not None: 

245 u0_hat = self.spectral.transform(u0).flatten() 

246 else: 

247 u0_hat = None 

248 

249 # apply inverse right preconditioner to initial guess 

250 if u0_hat is not None and 'direct' not in self.solver_type: 

251 if not hasattr(self, '_Pr_inv'): 

252 self._PR_inv = self.linalg.splu(self.Pr.astype(complex)).solve 

253 u0_hat[...] = self._PR_inv(u0_hat) 

254 

255 rhs_hat = (self.M @ rhs_hat.flatten()).reshape(rhs_hat.shape) 

256 rhs_hat = self.spectral.put_BCs_in_rhs_hat(rhs_hat) 

257 rhs_hat = self.Pl @ rhs_hat.flatten() 

258 

259 if dt not in self.cached_factorizations.keys() or not self.solver_type.lower() == 'cached_direct': 

260 A = self.M + dt * self.L 

261 A = self.Pl @ self.spectral.put_BCs_in_matrix(A) @ self.Pr 

262 

263 # if A.shape[0] < 200e20: 

264 # import matplotlib.pyplot as plt 

265 

266 # # M = self.spectral.put_BCs_in_matrix(self.L.copy()) 

267 # M = A # self.L 

268 # im = plt.spy(M) 

269 # plt.show() 

270 

271 if 'ilu' in self.solver_type.lower(): 

272 if dt not in self.cached_factorizations.keys(): 

273 if len(self.cached_factorizations) >= self.max_cached_factorizations: 

274 to_evict = list(self.cached_factorizations.keys())[0] 

275 self.cached_factorizations.pop(to_evict) 

276 self.logger.debug(f'Evicted matrix factorization for {to_evict=:.6f} from cache') 

277 iLU = self.linalg.spilu( 

278 A, **{**self.preconditioner_args, 'drop_tol': dt * self.preconditioner_args['drop_tol']} 

279 ) 

280 self.cached_factorizations[dt] = self.linalg.LinearOperator(A.shape, iLU.solve) 

281 self.logger.debug(f'Cached incomplete LU factorization for {dt=:.6f}') 

282 self.work_counters['factorizations']() 

283 M = self.cached_factorizations[dt] 

284 else: 

285 M = None 

286 info = 0 

287 

288 if self.solver_type.lower() == 'cached_direct': 

289 if dt not in self.cached_factorizations.keys(): 

290 if len(self.cached_factorizations) >= self.max_cached_factorizations: 

291 self.cached_factorizations.pop(list(self.cached_factorizations.keys())[0]) 

292 self.logger.debug(f'Evicted matrix factorization for {dt=:.6f} from cache') 

293 self.cached_factorizations[dt] = self.spectral.linalg.factorized(A) 

294 self.logger.debug(f'Cached matrix factorization for {dt=:.6f}') 

295 self.work_counters['factorizations']() 

296 

297 _sol_hat = self.cached_factorizations[dt](rhs_hat) 

298 self.logger.debug(f'Used cached matrix factorization for {dt=:.6f}') 

299 

300 elif self.solver_type.lower() == 'direct': 

301 _sol_hat = sp.linalg.spsolve(A, rhs_hat) 

302 elif 'gmres' in self.solver_type.lower(): 

303 _sol_hat, _ = sp.linalg.gmres( 

304 A, 

305 rhs_hat, 

306 x0=u0_hat, 

307 **self.solver_args, 

308 callback=self.work_counters[self.solver_type], 

309 callback_type='pr_norm', 

310 M=M, 

311 ) 

312 elif self.solver_type.lower() == 'cg': 

313 _sol_hat, info = sp.linalg.cg( 

314 A, rhs_hat, x0=u0_hat, **self.solver_args, callback=self.work_counters[self.solver_type] 

315 ) 

316 elif 'bicgstab' in self.solver_type.lower(): 

317 _sol_hat, info = self.linalg.bicgstab( 

318 A, 

319 rhs_hat, 

320 x0=u0_hat, 

321 **self.solver_args, 

322 callback=self.work_counters[self.solver_type], 

323 M=M, 

324 ) 

325 else: 

326 raise NotImplementedError(f'Solver {self.solver_type=} not implemented in {type(self).__name__}!') 

327 

328 if info != 0: 

329 self.logger.warn(f'{self.solver_type} not converged! {info=}') 

330 

331 sol_hat = self.spectral.u_init_forward 

332 sol_hat[...] = (self.Pr @ _sol_hat).reshape(sol_hat.shape) 

333 

334 if self.spectral_space: 

335 return sol_hat 

336 else: 

337 sol = self.spectral.u_init 

338 sol[:] = self.spectral.itransform(sol_hat).real 

339 

340 if self.spectral.debug: 

341 self.spectral.check_BCs(sol) 

342 

343 return sol 

344 

345 def setUpFieldsIO(self): 

346 Rectilinear.setupMPI( 

347 comm=self.comm, 

348 iLoc=[me.start for me in self.local_slice(False)], 

349 nLoc=[me.stop - me.start for me in self.local_slice(False)], 

350 ) 

351 

352 def getOutputFile(self, fileName): 

353 self.setUpFieldsIO() 

354 

355 coords = [me.get_1dgrid() for me in self.spectral.axes] 

356 assert np.allclose([len(me) for me in coords], self.spectral.global_shape[1:]) 

357 

358 fOut = Rectilinear(np.float64, fileName=fileName) 

359 fOut.setHeader(nVar=len(self.components), coords=coords) 

360 fOut.initialize() 

361 return fOut 

362 

363 def processSolutionForOutput(self, u): 

364 if self.spectral_space: 

365 return np.array(self.itransform(u).real) 

366 else: 

367 return np.array(u.real) 

368 

369 

370def compute_residual_DAE(self, stage=''): 

371 """ 

372 Computation of the residual that does not add u_0 - u_m in algebraic equations. 

373 

374 Args: 

375 stage (str): The current stage of the step the level belongs to 

376 """ 

377 

378 # get current level and problem description 

379 L = self.level 

380 

381 # Check if we want to skip the residual computation to gain performance 

382 # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual! 

383 if stage in self.params.skip_residual_computation: 

384 L.status.residual = 0.0 if L.status.residual is None else L.status.residual 

385 return None 

386 

387 # check if there are new values (e.g. from a sweep) 

388 # assert L.status.updated 

389 

390 # compute the residual for each node 

391 

392 # build QF(u) 

393 res_norm = [] 

394 res = self.integrate() 

395 mask = L.prob.diff_mask 

396 for m in range(self.coll.num_nodes): 

397 res[m][mask] += L.u[0][mask] - L.u[m + 1][mask] 

398 # add tau if associated 

399 if L.tau[m] is not None: 

400 res[m] += L.tau[m] 

401 # use abs function from data type here 

402 res_norm.append(abs(res[m])) 

403 

404 # find maximal residual over the nodes 

405 if L.params.residual_type == 'full_abs': 

406 L.status.residual = max(res_norm) 

407 elif L.params.residual_type == 'last_abs': 

408 L.status.residual = res_norm[-1] 

409 elif L.params.residual_type == 'full_rel': 

410 L.status.residual = max(res_norm) / abs(L.u[0]) 

411 elif L.params.residual_type == 'last_rel': 

412 L.status.residual = res_norm[-1] / abs(L.u[0]) 

413 else: 

414 raise ParameterError( 

415 f'residual_type = {L.params.residual_type} not implemented, choose ' 

416 f'full_abs, last_abs, full_rel or last_rel instead' 

417 ) 

418 

419 # indicate that the residual has seen the new values 

420 L.status.updated = False 

421 

422 return None 

423 

424 

425def compute_residual_DAE_MPI(self, stage=None): 

426 """ 

427 Computation of the residual using the collocation matrix Q 

428 

429 Args: 

430 stage (str): The current stage of the step the level belongs to 

431 """ 

432 from mpi4py import MPI 

433 

434 L = self.level 

435 

436 # Check if we want to skip the residual computation to gain performance 

437 # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual! 

438 if stage in self.params.skip_residual_computation: 

439 L.status.residual = 0.0 if L.status.residual is None else L.status.residual 

440 return None 

441 

442 # compute the residual for each node 

443 

444 # build QF(u) 

445 res = self.integrate(last_only=L.params.residual_type[:4] == 'last') 

446 mask = L.prob.diff_mask 

447 res[mask] += L.u[0][mask] - L.u[self.rank + 1][mask] 

448 # add tau if associated 

449 if L.tau[self.rank] is not None: 

450 res += L.tau[self.rank] 

451 # use abs function from data type here 

452 res_norm = abs(res) 

453 

454 # find maximal residual over the nodes 

455 if L.params.residual_type == 'full_abs': 

456 L.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX) 

457 elif L.params.residual_type == 'last_abs': 

458 L.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1) 

459 elif L.params.residual_type == 'full_rel': 

460 L.status.residual = self.comm.allreduce(res_norm / abs(L.u[0]), op=MPI.MAX) 

461 elif L.params.residual_type == 'last_rel': 

462 L.status.residual = self.comm.bcast(res_norm / abs(L.u[0]), root=self.comm.size - 1) 

463 else: 

464 raise NotImplementedError(f'residual type \"{L.params.residual_type}\" not implemented!') 

465 

466 # indicate that the residual has seen the new values 

467 L.status.updated = False 

468 

469 return None 

470 

471 

472def get_extrapolated_error_DAE(self, S, **kwargs): 

473 """ 

474 The extrapolation estimate combines values of u and f from multiple steps to extrapolate and compare to the 

475 solution obtained by the time marching scheme. This function can be used in `EstimateExtrapolationError`. 

476 

477 Args: 

478 S (pySDC.Step): The current step 

479 

480 Returns: 

481 None 

482 """ 

483 u_ex = self.get_extrapolated_solution(S) 

484 diff_mask = S.levels[0].prob.diff_mask 

485 if u_ex is not None: 

486 S.levels[0].status.error_extrapolation_estimate = ( 

487 abs((u_ex - S.levels[0].u[-1])[diff_mask]) * self.coeff.prefactor 

488 ) 

489 # print([abs(me) for me in (u_ex - S.levels[0].u[-1]) * self.coeff.prefactor]) 

490 else: 

491 S.levels[0].status.error_extrapolation_estimate = None