Coverage for pySDC / core / sweeper.py: 94%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-20 16:04 +0000

1import logging 

2from typing import Any, Dict, Optional, TYPE_CHECKING 

3import numpy as np 

4from qmat.qdelta import QDeltaGenerator, QDELTA_GENERATORS 

5 

6from pySDC.core.errors import ParameterError 

7from pySDC.core.collocation import CollBase 

8from pySDC.helpers.pysdc_helper import FrozenClass 

9 

10if TYPE_CHECKING: 

11 from pySDC.core.level import Level 

12 

# Invert QDELTA_GENERATORS (alias -> QDeltaGenerator class) into
# QDELTA_GENERATORS_ALIASES (QDeltaGenerator class -> set of every alias it is registered under),
# so that an already-built generator can be matched against any of its names.
QDELTA_GENERATORS_ALIASES = {}
for _alias, _generator_cls in QDELTA_GENERATORS.items():
    QDELTA_GENERATORS_ALIASES.setdefault(_generator_cls, set()).add(_alias)

17 

18 

19# short helper class to add params as attributes 

20class _Pars(FrozenClass): 

21 def __init__(self, pars: Dict[str, Any]) -> None: 

22 self.do_coll_update: bool = False 

23 self.initial_guess: str = 'spread' # default value (see also below) 

24 self.skip_residual_computation: tuple = () # gain performance at the cost of correct residual output 

25 

26 for k, v in pars.items(): 

27 if k != 'collocation_class': 

28 setattr(self, k, v) 

29 

30 self._freeze() 

31 

32 

class Sweeper(object):
    """
    Base abstract sweeper class, provides two base methods to generate QDelta matrices:

    - get_Qdelta_implicit(qd_type):
        Returns a (pySDC-type) QDelta matrix of **implicit type**,
        *i.e* lower triangular with zeros on the first column.
    - get_Qdelta_explicit(qd_type):
        Returns a (pySDC-type) QDelta matrix of **explicit type**,
        *i.e* strictly lower triangular with first node distance to zero on the first column.


    All possible QDelta matrix coefficients are generated with
    `qmat <https://qmat.readthedocs.io/en/latest/autoapi/qmat/qdelta/index.html>`_,
    check it out to see all available coefficient types.

    Subclasses have to implement ``integrate``, ``update_nodes`` and ``compute_end_point``.

    Attributes:
        logger: custom logger for sweeper-related logging
        params (_Pars): parameter object containing the custom parameters passed by the user
        coll (pySDC.Collocation.CollBase): collocation object
        parallelizable (bool): True only if the most recently generated implicit and/or explicit
            QDelta matrices are all diagonal (so the node updates decouple)
        rng (numpy.random.RandomState): random generator, only created when
            ``initial_guess == 'random'``
    """

    def __init__(self, params: Dict[str, Any], level: 'Level') -> None:
        """
        Initialization routine for the base sweeper

        Note that ``params`` is modified in place: ``collocation_class`` is inserted when
        missing, and so is ``random_seed`` (default 1984) when ``initial_guess == 'random'``.

        Args:
            params (dict): parameter object
            level (pySDC.Level.level): the level that uses this sweeper

        Raises:
            ParameterError: if an essential parameter (``num_nodes``) is missing
        """

        self.logger: logging.Logger = logging.getLogger('sweeper')

        essential_keys = ['num_nodes']
        for key in essential_keys:
            if key not in params:
                msg = 'need %s to instantiate sweeper, only got %s' % (key, str(params.keys()))
                self.logger.error(msg)
                raise ParameterError(msg)

        if 'collocation_class' not in params:
            params['collocation_class'] = CollBase

        # prepare random generator for initial guess
        if params.get('initial_guess', 'spread') == 'random':  # default value (see also _Pars)
            params['random_seed'] = params.get('random_seed', 1984)
            self.rng = np.random.RandomState(params['random_seed'])

        self.params = _Pars(params)

        # the collocation class receives the complete parameter dict as keyword arguments
        self.coll: CollBase = params['collocation_class'](**params)

        if not self.coll.right_is_node and not self.params.do_coll_update:
            self.logger.warning(
                'we need to do a collocation update here, since the right end point is not a node. Changing this!'
            )
            self.params.do_coll_update = True

        self.__level = level
        self.parallelizable = False
        # diagonality of the latest QDelta matrix per kind ('implicit' / 'explicit'),
        # maintained by _update_parallelizable
        self._qdelta_is_diagonal: Dict[str, bool] = {}
        # discard QDelta generators already attached to this instance, so that the
        # get_Qdelta_* methods build fresh ones for this collocation
        for name in ["genQI", "genQE"]:
            if hasattr(self, name):
                delattr(self, name)

    def buildGenerator(self, qdType: str) -> QDeltaGenerator:
        """
        Instantiate the qmat QDelta generator registered under ``qdType`` for this collocation.

        Args:
            qdType (str): name or alias of a generator in QDELTA_GENERATORS

        Returns:
            QDeltaGenerator: generator built from the collocation generator and left boundary
        """
        return QDELTA_GENERATORS[qdType](qGen=self.coll.generator, tLeft=self.coll.tleft)

    def _update_parallelizable(self, kind: str, QDmat: np.ndarray) -> None:
        """
        Record whether the latest QDelta matrix of the given kind is diagonal and refresh
        ``self.parallelizable``.

        The sweeper is only flagged parallelizable if the latest matrix of *every* kind generated
        so far is diagonal. Previously, a single diagonal matrix was enough. That wrongly marked
        e.g. QI='IE' combined with QE='PIC' as parallelizable.

        Args:
            kind (str): 'implicit' or 'explicit'
            QDmat (numpy.ndarray): the freshly generated QDelta matrix
        """
        if not hasattr(self, '_qdelta_is_diagonal'):  # robustness for subclasses skipping __init__
            self._qdelta_is_diagonal = {}
        self._qdelta_is_diagonal[kind] = bool(np.allclose(np.diag(np.diag(QDmat)), QDmat))
        self.parallelizable = all(self._qdelta_is_diagonal.values())

    def get_Qdelta_implicit(self, qd_type: str, k: Optional[int] = None) -> np.ndarray:
        """
        Generate a QDelta matrix of implicit type (lower triangular, zero first row and column).

        Args:
            qd_type (str): name or alias of a qmat QDelta generator
            k (int, optional): sweep index, forwarded to the generator's ``genCoeffs``

        Returns:
            numpy.ndarray: matrix with the same shape as ``self.coll.Qmat``

        Raises:
            AssertionError: if the generated matrix is not lower triangular
        """
        QDmat = np.zeros_like(self.coll.Qmat)
        # reuse the cached generator when it is registered under qd_type, otherwise rebuild it
        if not hasattr(self, "genQI") or qd_type not in QDELTA_GENERATORS_ALIASES[type(self.genQI)]:
            self.genQI: QDeltaGenerator = self.buildGenerator(qd_type)
        QDmat[1:, 1:] = self.genQI.genCoeffs(k=k)

        err_msg = 'Lower triangular matrix expected!'
        np.testing.assert_array_equal(np.triu(QDmat, k=1), np.zeros(QDmat.shape), err_msg=err_msg)
        self._update_parallelizable('implicit', QDmat)
        return QDmat

    def get_Qdelta_explicit(self, qd_type: str, k: Optional[int] = None) -> np.ndarray:
        """
        Generate a QDelta matrix of explicit type (strictly lower triangular).

        The first column (below row 0) is filled with the ``dTau`` values that the generator
        returns alongside its coefficients.

        Args:
            qd_type (str): name or alias of a qmat QDelta generator
            k (int, optional): sweep index, forwarded to the generator's ``genCoeffs``

        Returns:
            numpy.ndarray: float matrix with the same shape as ``self.coll.Qmat``

        Raises:
            AssertionError: if the generated matrix is not strictly lower triangular
        """
        coll = self.coll
        QDmat = np.zeros(coll.Qmat.shape, dtype=float)
        # reuse the cached generator when it is registered under qd_type, otherwise rebuild it
        if not hasattr(self, "genQE") or qd_type not in QDELTA_GENERATORS_ALIASES[type(self.genQE)]:
            self.genQE: QDeltaGenerator = self.buildGenerator(qd_type)
        QDmat[1:, 1:], QDmat[1:, 0] = self.genQE.genCoeffs(k=k, dTau=True)

        err_msg = 'Strictly lower triangular matrix expected!'
        np.testing.assert_array_equal(np.triu(QDmat, k=0), np.zeros(QDmat.shape), err_msg=err_msg)
        # an all-zero explicit matrix (as for PIC) counts as diagonal
        self._update_parallelizable('explicit', QDmat)
        return QDmat

    def predict(self) -> None:
        """
        Predictor to fill values at nodes before first sweep

        Default prediction for the sweepers. For ``initial_guess == 'spread'``, it copies u[0]
        to all collocation nodes and evaluates the RHS of the ODE there. Other options are
        'copy' (copy u[0] and f[0]), 'zero' and 'random'.

        Raises:
            ParameterError: if ``initial_guess`` is not one of the options above
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # evaluate RHS at left point
        L.f[0] = P.eval_f(L.u[0], L.time)

        for m in range(1, self.coll.num_nodes + 1):
            # copy u[0] to all collocation nodes, evaluate RHS
            if self.params.initial_guess == 'spread':
                L.u[m] = P.dtype_u(L.u[0])
                L.f[m] = P.eval_f(L.u[m], L.time + L.dt * self.coll.nodes[m - 1])
            # copy u[0] and RHS evaluation to all collocation nodes
            elif self.params.initial_guess == 'copy':
                L.u[m] = P.dtype_u(L.u[0])
                L.f[m] = P.dtype_f(L.f[0])
            # start with zero everywhere
            elif self.params.initial_guess == 'zero':
                L.u[m] = P.dtype_u(init=P.init, val=0.0)
                L.f[m] = P.dtype_f(init=P.init, val=0.0)
            # start with random initial guess (self.rng is created in __init__ for this option)
            elif self.params.initial_guess == 'random':
                L.u[m] = P.dtype_u(init=P.init, val=self.rng.rand(1)[0])
                L.f[m] = P.dtype_f(init=P.init, val=self.rng.rand(1)[0])
            else:
                raise ParameterError(f'initial_guess option {self.params.initial_guess} not implemented')

        # indicate that this level is now ready for sweeps
        L.status.unlocked = True
        L.status.updated = True

    def compute_residual(self, stage: str = '') -> None:
        """
        Computation of the residual using the collocation matrix Q

        For each node m the residual is ``integrate()[m] + u[0] - u[m+1]`` (plus ``tau[m]`` when
        set). Its norm is reduced according to ``L.params.residual_type`` and stored in
        ``L.status.residual``.

        Args:
            stage (str): The current stage of the step the level belongs to

        Raises:
            ParameterError: if ``residual_type`` is unknown
        """

        # get current level and problem description
        L = self.level

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            L.status.residual = 0.0 if L.status.residual is None else L.status.residual
            return None

        # build QF(u) and compute the residual for each node
        res_norm = []
        L.residual = self.integrate()
        for m in range(self.coll.num_nodes):
            L.residual[m] += L.u[0] - L.u[m + 1]
            # add tau if associated
            if L.tau[m] is not None:
                L.residual[m] += L.tau[m]
            # use abs function from data type here
            res_norm.append(abs(L.residual[m]))

        # find maximal residual over the nodes
        if L.params.residual_type == 'full_abs':
            L.status.residual = max(res_norm)
        elif L.params.residual_type == 'last_abs':
            L.status.residual = res_norm[-1]
        elif L.params.residual_type == 'full_rel':
            L.status.residual = max(res_norm) / abs(L.u[0])
        elif L.params.residual_type == 'last_rel':
            L.status.residual = res_norm[-1] / abs(L.u[0])
        else:
            raise ParameterError(
                f'residual_type = {L.params.residual_type} not implemented, choose '
                'full_abs, last_abs, full_rel or last_rel instead'
            )

        # indicate that the residual has seen the new values
        L.status.updated = False

        return None

    def compute_end_point(self) -> None:
        """
        Abstract interface to end-node computation
        """
        raise NotImplementedError('ERROR: sweeper has to implement compute_end_point(self)')

    def integrate(self) -> Any:
        """
        Abstract interface to right-hand side integration
        """
        raise NotImplementedError('ERROR: sweeper has to implement integrate(self)')

    def update_nodes(self) -> None:
        """
        Abstract interface to node update
        """
        raise NotImplementedError('ERROR: sweeper has to implement update_nodes(self)')

    @property
    def level(self) -> 'Level':
        """
        Returns the current level

        Returns:
            pySDC.Level.level: the current level
        """
        return self.__level

    @level.setter
    def level(self, L: 'Level') -> None:
        """
        Sets a reference to the current level (done in the initialization of the level)

        Args:
            L (pySDC.Level.level): current level
        """
        # local import to avoid a circular import at module load time
        from pySDC.core.level import Level

        assert isinstance(L, Level)
        self.__level = L

    @property
    def rank(self) -> int:
        """Rank of this sweeper; always 0 for the serial base class."""
        return 0

    def updateVariableCoeffs(self, k: int) -> None:
        """
        Regenerate QI and/or QE for sweep ``k`` when their cached generators are k-dependent.

        Parameters
        ----------
        k : int
            Index of the sweep (0 for initial sweep, 1 for the first one, ...).
        """
        if hasattr(self, "genQI") and self.genQI.isKDependent():
            # the generator's class name is used as qd_type to re-select the cached generator
            qdType = type(self.genQI).__name__
            self.QI = self.get_Qdelta_implicit(qdType, k=k)
        if hasattr(self, "genQE") and self.genQE.isKDependent():
            qdType = type(self.genQE).__name__
            self.QE = self.get_Qdelta_explicit(qdType, k=k)