Coverage for pySDC/projects/PinTSimE/switch_estimator.py: 99%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-20 16:55 +0000

1import numpy as np 

2import scipy as sp 

3 

4from pySDC.core.errors import ParameterError 

5from pySDC.core.collocation import CollBase 

6from pySDC.core.convergence_controller import ConvergenceController, Status 

7from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence 

8from qmat.lagrange import LagrangeApproximation 

9 

10 

11class SwitchEstimator(ConvergenceController): 

12 """ 

13 Class to predict the time point of the event and setting a new step size. For the first time, this is a nonMPI version, 

14 because a MPI version is not yet developed. 

15 """ 

16 

17 def setup(self, controller, params, description): 

18 r""" 

19 Function sets default variables to handle with the event at the beginning. The default params are: 

20 

21 - control_order : controls the order of the SE's call of convergence controllers. 

22 - coll.nodes : defines the collocation nodes for interpolation. 

23 - tol_zero : inner tolerance for SE; state function has to satisfy it to terminate. 

24 - t_interp : interpolation axis with time points. 

25 - state_function : List of values from state function. 

26 

27 Parameters 

28 ---------- 

29 controller : pySDC.Controller 

30 The controller doing all the stuff in a computation. 

31 params : dict 

32 The parameters passed for this specific convergence controller. 

33 description : dict 

34 The description object used to instantiate the controller. 

35 

36 Returns 

37 ------- 

38 convergence_controller_params : dict 

39 The updated params dictionary. 

40 """ 

41 

42 # for RK4 sweeper, sweep.coll.nodes now consists of values of ButcherTableau 

43 # for this reason, collocation nodes will be generated here 

44 coll = CollBase( 

45 num_nodes=description['sweeper_params']['num_nodes'], 

46 quad_type=description['sweeper_params']['quad_type'], 

47 ) 

48 

49 defaults = { 

50 'control_order': 0, 

51 'nodes': coll.nodes, 

52 'tol_zero': 2.5e-12, 

53 't_interp': [], 

54 'state_function': [], 

55 } 

56 return {**defaults, **params} 

57 

58 def setup_status_variables(self, controller, **kwargs): 

59 """ 

60 Adds switching specific variables to status variables. 

61 

62 Parameters 

63 ---------- 

64 controller : pySDC.Controller 

65 The controller doing all the stuff in a computation. 

66 """ 

67 

68 self.status = Status(['is_zero', 'switch_detected', 't_switch']) 

69 

70 def reset_status_variables(self, controller, **kwargs): 

71 """ 

72 Resets status variables. 

73 

74 Parameters 

75 ---------- 

76 controller : pySDC.Controller 

77 The controller doing all the stuff in a computation. 

78 """ 

79 

80 self.setup_status_variables(controller, **kwargs) 

81 

82 def get_new_step_size(self, controller, S, **kwargs): 

83 """ 

84 Determine a new step size when an event is found such that the event occurs at the time step. 

85 

86 Parameters 

87 ---------- 

88 controller : pySDC.Controller 

89 The controller doing all the stuff in a computation. 

90 S : pySDC.Step 

91 The current step. 

92 """ 

93 

94 L = S.levels[0] 

95 

96 if CheckConvergence.check_convergence(S): 

97 self.status.switch_detected, m_guess, self.params.state_function = L.prob.get_switching_info(L.u, L.time) 

98 

99 if self.status.switch_detected: 

100 self.params.t_interp = [L.time + L.dt * self.params.nodes[m] for m in range(len(self.params.nodes))] 

101 self.params.t_interp, self.params.state_function = self.adapt_interpolation_info( 

102 L.time, L.sweep.coll.left_is_node, self.params.t_interp, self.params.state_function 

103 ) 

104 

105 # when the state function is already close to zero the event is already resolved well 

106 if ( 

107 abs(self.params.state_function[-1]) <= self.params.tol_zero 

108 or abs(self.params.state_function[0]) <= self.params.tol_zero 

109 ): 

110 if abs(self.params.state_function[0]) <= self.params.tol_zero: 

111 t_switch = self.params.t_interp[0] 

112 boundary = 'left' 

113 elif abs(self.params.state_function[-1]) <= self.params.tol_zero: 

114 boundary = 'right' 

115 t_switch = self.params.t_interp[-1] 

116 

117 msg = f"The value of state function is close to zero, thus event time is already close enough to the {boundary} end point!" 

118 self.log(msg, S) 

119 self.log_event_time( 

120 controller.hooks[0], S.status.slot, L.time, L.level_index, L.status.sweep, t_switch 

121 ) 

122 

123 L.prob.count_switches() 

124 self.status.is_zero = True 

125 

126 # intermediate value theorem states that a root is contained in current step 

127 if self.params.state_function[0] * self.params.state_function[-1] < 0 and self.status.is_zero is None: 

128 self.status.t_switch = self.get_switch(self.params.t_interp, self.params.state_function, m_guess) 

129 

130 self.logging_during_estimation( 

131 controller.hooks[0], 

132 S.status.slot, 

133 L.time, 

134 L.level_index, 

135 L.status.sweep, 

136 self.status.t_switch, 

137 self.params.state_function, 

138 ) 

139 

140 if L.time < self.status.t_switch < L.time + L.dt: 

141 dt_switch = (self.status.t_switch - L.time) * self.params.alpha 

142 

143 if ( 

144 abs(self.status.t_switch - L.time) <= self.params.tol 

145 or abs((L.time + L.dt) - self.status.t_switch) <= self.params.tol 

146 ): 

147 self.log(f"Switch located at time {self.status.t_switch:.15f}", S) 

148 L.prob.t_switch = self.status.t_switch 

149 self.log_event_time( 

150 controller.hooks[0], 

151 S.status.slot, 

152 L.time, 

153 L.level_index, 

154 L.status.sweep, 

155 self.status.t_switch, 

156 ) 

157 

158 L.prob.count_switches() 

159 

160 else: 

161 self.log(f"Located Switch at time {self.status.t_switch:.15f} is outside the range", S) 

162 

163 # when an event is found, step size matching with this event should be preferred 

164 dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt 

165 if self.status.switch_detected: 

166 L.status.dt_new = dt_switch 

167 else: 

168 L.status.dt_new = min([dt_planned, dt_switch]) 

169 

170 else: 

171 # event occurs on L.time or L.time + L.dt; no restart necessary 

172 boundary = 'left boundary' if self.status.t_switch == L.time else 'right boundary' 

173 self.log(f"Estimated switch {self.status.t_switch:.15f} occurs at {boundary}", S) 

174 self.log_event_time( 

175 controller.hooks[0], 

176 S.status.slot, 

177 L.time, 

178 L.level_index, 

179 L.status.sweep, 

180 self.status.t_switch, 

181 ) 

182 L.prob.count_switches() 

183 self.status.switch_detected = False 

184 

185 else: # intermediate value theorem is not satisfied 

186 self.status.switch_detected = False 

187 

188 def determine_restart(self, controller, S, **kwargs): 

189 """ 

190 Check if the step needs to be restarted due to a predicting switch. 

191 

192 Parameters 

193 ---------- 

194 controller : pySDC.Controller 

195 The controller doing all the stuff in a computation. 

196 S : pySDC.Step 

197 The current step. 

198 """ 

199 

200 if self.status.switch_detected: 

201 S.status.restart = True 

202 S.status.force_done = True 

203 

204 super().determine_restart(controller, S, **kwargs) 

205 

206 def post_step_processing(self, controller, S, **kwargs): 

207 """ 

208 After a step is done, some variables will be prepared for predicting a possibly new switch. 

209 If no Adaptivity is used, the next time step will be set as the default one from the front end. 

210 

211 Parameters 

212 ---------- 

213 controller : pySDC.Controller 

214 The controller doing all the stuff in a computation. 

215 S : pySDC.Step 

216 The current step. 

217 """ 

218 

219 L = S.levels[0] 

220 

221 if self.status.t_switch is None: 

222 L.status.dt_new = L.status.dt_new if L.status.dt_new is not None else L.params.dt_initial 

223 

224 super().post_step_processing(controller, S, **kwargs) 

225 

226 @staticmethod 

227 def log_event_time(controller_hooks, process, time, level, sweep, t_switch): 

228 """ 

229 Logs the event time of an event satisfying an appropriate criterion, e.g., event is already resolved well, 

230 event time satisfies tolerance. 

231 

232 Parameters 

233 ---------- 

234 controller_hooks : pySDC.Controller.hooks 

235 Controller with access to the hooks. 

236 process : int 

237 Process for logging. 

238 time : float 

239 Time at which the event time is logged (denotes the current step). 

240 level : int 

241 Level at which event is found. 

242 sweep : int 

243 Denotes the number of sweep. 

244 t_switch : float 

245 Event time founded by switch estimation. 

246 """ 

247 

248 controller_hooks.add_to_stats( 

249 process=process, 

250 time=time, 

251 level=level, 

252 iter=0, 

253 sweep=sweep, 

254 type='switch', 

255 value=t_switch, 

256 ) 

257 

258 @staticmethod 

259 def logging_during_estimation(controller_hooks, process, time, level, sweep, t_switch, state_function): 

260 controller_hooks.add_to_stats( 

261 process=process, 

262 time=time, 

263 level=level, 

264 iter=0, 

265 sweep=sweep, 

266 type='switch_all', 

267 value=t_switch, 

268 ) 

269 controller_hooks.add_to_stats( 

270 process=process, 

271 time=time, 

272 level=level, 

273 iter=0, 

274 sweep=sweep, 

275 type='h_all', 

276 value=max([abs(item) for item in state_function]), 

277 ) 

278 

279 @staticmethod 

280 def get_switch(t_interp, state_function, m_guess): 

281 r""" 

282 Routine to do the interpolation and root finding stuff. 

283 

284 Parameters 

285 ---------- 

286 t_interp : list 

287 Collocation nodes in a step. 

288 state_function : list 

289 Contains values of state function at these collocation nodes. 

290 m_guess : float 

291 Index at which the difference drops below zero. 

292 

293 Returns 

294 ------- 

295 t_switch : float 

296 Time point of found event. 

297 """ 

298 

299 LagrangeInterpolation = LagrangeApproximation(points=t_interp, fValues=state_function) 

300 

301 def p(t): 

302 return LagrangeInterpolation.__call__(t) 

303 

304 def fprime(t): 

305 r""" 

306 Computes the derivative of the scalar interpolant using finite difference. Here, 

307 the derivative is approximated by the backward difference: 

308 

309 .. math:: 

310 \frac{dp}{dt} \approx \frac{25 p(t) - 48 p(t - h) + 36 p(t - 2 h) - 16 p(t - 3h) + 3 p(t - 4 h)}{12 h} 

311 

312 

313 Parameters 

314 ---------- 

315 t : float 

316 Time where the derivatives is computed. 

317 

318 Returns 

319 ------- 

320 dp : float 

321 Derivative of interpolation p at time t. 

322 """ 

323 

324 dt_FD = 1e-10 

325 dp = ( 

326 25 * p(t) - 48 * p(t - dt_FD) + 36 * p(t - 2 * dt_FD) - 16 * p(t - 3 * dt_FD) + 3 * p(t - 4 * dt_FD) 

327 ) / (12 * dt_FD) 

328 return dp 

329 

330 newton_tol, newton_maxiter = 1e-14, 100 

331 t_switch = newton(t_interp[m_guess], p, fprime, newton_tol, newton_maxiter) 

332 return t_switch 

333 

334 @staticmethod 

335 def adapt_interpolation_info(t, left_is_node, t_interp, state_function): 

336 """ 

337 Adapts the x- and y-axis for interpolation. For SDC, it is proven whether the left boundary is a 

338 collocation node or not. In case it is, the first entry of the state function has to be removed, 

339 because it would otherwise contain double values on starting time and the first node. Otherwise, 

340 starting time L.time has to be added to t_interp to also take this value in the interpolation 

341 into account. 

342 

343 Parameters 

344 ---------- 

345 t : float 

346 Starting time of the step. 

347 left_is_node : bool 

348 Indicates whether the left boundary is a collocation node or not. 

349 t_interp : list 

350 x-values for interpolation containing collocation nodes. 

351 state_function : list 

352 y-values for interpolation containing values of state function. 

353 

354 Returns 

355 ------- 

356 t_interp : list 

357 Adapted x-values for interpolation containing collocation nodes. 

358 state_function : list 

359 Adapted y-values for interpolation containing values of state function. 

360 """ 

361 

362 if not left_is_node: 

363 t_interp.insert(0, t) 

364 else: 

365 del state_function[0] 

366 

367 return t_interp, state_function 

368 

369 

370def newton(x0, p, fprime, newton_tol, newton_maxiter): 

371 """ 

372 Newton's method fo find the root of interpolant p. 

373 

374 Parameters 

375 ---------- 

376 x0 : float 

377 Initial guess. 

378 p : callable 

379 Interpolated function where Newton's method is applied at. 

380 fprime : callable 

381 Approximated derivative of p using finite differences. 

382 newton_tol : float 

383 Tolerance for termination. 

384 newton_maxiter : int 

385 Maximum of iterations the method should execute. 

386 

387 Returns 

388 ------- 

389 root : float 

390 Root of function p. 

391 """ 

392 

393 n = 0 

394 while n < newton_maxiter: 

395 res = abs(p(x0)) 

396 if res < newton_tol or np.isnan(p(x0)) and np.isnan(fprime(x0)) or np.isclose(fprime(x0), 0.0): 

397 break 

398 

399 x0 -= 1.0 / fprime(x0) * p(x0) 

400 

401 n += 1 

402 

403 if n == newton_maxiter: 

404 msg = f'Newton did not converge after {n} iterations, error is {res}' 

405 else: 

406 msg = f'Newton did converge after {n} iterations, error for root {x0} is {res}' 

407 print(msg) 

408 

409 root = x0 

410 return root