Coverage for pySDC/projects/PinTSimE/switch_estimator.py: 99%

104 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import numpy as np 

2import scipy as sp 

3 

4from pySDC.core.Errors import ParameterError 

5from pySDC.core.Collocation import CollBase 

6from pySDC.core.ConvergenceController import ConvergenceController, Status 

7from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence 

8from pySDC.core.Lagrange import LagrangeApproximation 

9 

10 

class SwitchEstimator(ConvergenceController):
    """
    Convergence controller that predicts the time point of an event (a "switch") inside a step and
    sets a new step size accordingly. This is a non-MPI version, because an MPI version is not yet developed.
    """

    def setup(self, controller, params, description):
        r"""
        Set the default parameters that are needed to handle an event. The defaults are:

        - control_order  : controls the order of the SE's call of convergence controllers.
        - nodes          : collocation nodes used for the interpolation.
        - tol_zero       : inner tolerance for SE; state function has to satisfy it to terminate.
        - t_interp       : interpolation axis with time points.
        - state_function : list of values from the state function.

        In addition, ``get_new_step_size`` reads ``params.alpha`` (factor applied to the distance between
        the start of the step and the estimated event time) and ``params.tol`` (tolerance for the distance
        between the estimated event time and a step boundary). No defaults are set for these two here, so
        they have to be passed via ``params``.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        params : dict
            The parameters passed for this specific convergence controller.
        description : dict
            The description object used to instantiate the controller.

        Returns
        -------
        convergence_controller_params : dict
            The updated params dictionary. Entries of ``params`` take precedence over the defaults.
        """

        # for RK4 sweeper, sweep.coll.nodes now consists of values of ButcherTableau
        # for this reason, collocation nodes will be generated here
        sweeper_params = description['sweeper_params']
        coll = CollBase(
            num_nodes=sweeper_params['num_nodes'],
            quad_type=sweeper_params['quad_type'],
        )

        defaults = {
            'control_order': 0,
            'nodes': coll.nodes,
            'tol_zero': 2.5e-12,
            't_interp': [],
            'state_function': [],
        }
        return {**defaults, **params}

    def setup_status_variables(self, controller, **kwargs):
        """
        Adds switching specific variables (``is_zero``, ``switch_detected``, ``t_switch``) to the
        status variables.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        """

        # NOTE(review): the rest of this class tests these entries against None, so Status presumably
        # initialises them to None -- confirm against the Status implementation.
        self.status = Status(['is_zero', 'switch_detected', 't_switch'])

    def reset_status_variables(self, controller, **kwargs):
        """
        Resets the status variables by creating them anew.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        """

        self.setup_status_variables(controller, **kwargs)

    def get_new_step_size(self, controller, S, **kwargs):
        """
        Determine a new step size when an event is found such that the event occurs at the time step.

        The search is only performed once the step has converged and the problem reports an event via
        ``get_switching_info``. Three cases are distinguished afterwards:

        - The state function is already within ``tol_zero`` of zero at the first or last interpolation
          point: the event time is logged, the switch is counted and no root finding is done.
        - The state function changes its sign over the step: the event time is estimated by
          ``get_switch`` and, if it lies strictly inside the step, a new step size is proposed.
        - Otherwise ``switch_detected`` is reset to ``False``.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        S : pySDC.Step
            The current step.
        """

        L = S.levels[0]

        if not CheckConvergence.check_convergence(S):
            return

        self.status.switch_detected, m_guess, self.params.state_function = L.prob.get_switching_info(L.u, L.time)

        if not self.status.switch_detected:
            return

        self.params.t_interp = [L.time + L.dt * node for node in self.params.nodes]
        self.params.t_interp, self.params.state_function = self.adapt_interpolation_info(
            L.time, L.sweep.coll.left_is_node, self.params.t_interp, self.params.state_function
        )
        t_interp = self.params.t_interp
        state_function = self.params.state_function

        # when the state function is already close to zero the event is already resolved well
        left_is_zero = abs(state_function[0]) <= self.params.tol_zero
        right_is_zero = abs(state_function[-1]) <= self.params.tol_zero
        if left_is_zero or right_is_zero:
            # the left end point takes precedence when both ends satisfy the tolerance
            if left_is_zero:
                boundary, t_switch = 'left', t_interp[0]
            else:
                boundary, t_switch = 'right', t_interp[-1]

            msg = f"The value of state function is close to zero, thus event time is already close enough to the {boundary} end point!"
            self.log(msg, S)
            self.log_event_time(controller.hooks[0], S.status.slot, L.time, L.level_index, L.status.sweep, t_switch)

            L.prob.count_switches()
            self.status.is_zero = True

        # intermediate value theorem states that a root is contained in current step
        if state_function[0] * state_function[-1] < 0 and self.status.is_zero is None:
            self.status.t_switch = self.get_switch(t_interp, state_function, m_guess)

            self.logging_during_estimation(
                controller.hooks[0],
                S.status.slot,
                L.time,
                L.level_index,
                L.status.sweep,
                self.status.t_switch,
                state_function,
            )

            if L.time < self.status.t_switch < L.time + L.dt:
                dt_switch = (self.status.t_switch - L.time) * self.params.alpha

                close_to_left = abs(self.status.t_switch - L.time) <= self.params.tol
                if close_to_left or abs((L.time + L.dt) - self.status.t_switch) <= self.params.tol:
                    self.log(f"Switch located at time {self.status.t_switch:.15f}", S)
                    L.prob.t_switch = self.status.t_switch
                    self.log_event_time(
                        controller.hooks[0],
                        S.status.slot,
                        L.time,
                        L.level_index,
                        L.status.sweep,
                        self.status.t_switch,
                    )

                    L.prob.count_switches()

                else:
                    self.log(f"Located Switch at time {self.status.t_switch:.15f} is outside the range", S)

                # when an event is found, step size matching with this event should be preferred
                # NOTE(review): nothing visible in this method resets switch_detected before this point,
                # so the min(...) branch below looks unreachable -- confirm whether the flag was meant
                # to be cleared once the switch is located within tol.
                dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt
                if self.status.switch_detected:
                    L.status.dt_new = dt_switch
                else:
                    L.status.dt_new = min([dt_planned, dt_switch])

            else:
                # event occurs on L.time or L.time + L.dt; no restart necessary
                boundary = 'left boundary' if self.status.t_switch == L.time else 'right boundary'
                self.log(f"Estimated switch {self.status.t_switch:.15f} occurs at {boundary}", S)
                self.log_event_time(
                    controller.hooks[0],
                    S.status.slot,
                    L.time,
                    L.level_index,
                    L.status.sweep,
                    self.status.t_switch,
                )
                L.prob.count_switches()
                self.status.switch_detected = False

        else:  # intermediate value theorem is not satisfied
            self.status.switch_detected = False

    def determine_restart(self, controller, S, **kwargs):
        """
        Check if the step needs to be restarted due to a predicted switch. If a switch is (still)
        flagged as detected, the step is marked for restart and forced to be done.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        S : pySDC.Step
            The current step.
        """

        if self.status.switch_detected:
            S.status.restart = True
            S.status.force_done = True

        super().determine_restart(controller, S, **kwargs)

    def post_step_processing(self, controller, S, **kwargs):
        """
        After a step is done, some variables will be prepared for predicting a possibly new switch.
        If no event time was estimated in this step and no new step size has been set yet (e.g., when
        no Adaptivity is used), the next step size is set to the initial one from the front end.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        S : pySDC.Step
            The current step.
        """

        L = S.levels[0]

        if self.status.t_switch is None and L.status.dt_new is None:
            L.status.dt_new = L.params.dt_initial

        super().post_step_processing(controller, S, **kwargs)

    @staticmethod
    def log_event_time(controller_hooks, process, time, level, sweep, t_switch):
        """
        Logs the event time of an event satisfying an appropriate criterion, e.g., event is already resolved well,
        event time satisfies tolerance. The value is stored in the stats under the type ``'switch'``.

        Parameters
        ----------
        controller_hooks : pySDC.Controller.hooks
            Hook object providing ``add_to_stats``.
        process : int
            Process for logging.
        time : float
            Time at which the event time is logged (denotes the current step).
        level : int
            Level at which event is found.
        sweep : int
            Denotes the number of sweep.
        t_switch : float
            Event time found by switch estimation.
        """

        controller_hooks.add_to_stats(
            process=process,
            time=time,
            level=level,
            iter=0,
            sweep=sweep,
            type='switch',
            value=t_switch,
        )

    @staticmethod
    def logging_during_estimation(controller_hooks, process, time, level, sweep, t_switch, state_function):
        """
        Logs every estimated event time together with the largest absolute value of the state function.
        The estimate is stored under the type ``'switch_all'``, the state function value under ``'h_all'``.

        Parameters
        ----------
        controller_hooks : pySDC.Controller.hooks
            Hook object providing ``add_to_stats``.
        process : int
            Process for logging.
        time : float
            Time at which the values are logged (denotes the current step).
        level : int
            Level at which event is found.
        sweep : int
            Denotes the number of sweep.
        t_switch : float
            Event time estimated by switch estimation.
        state_function : list
            Values of the state function used for the estimation.
        """

        # both entries share the same key, only type and value differ
        stats_key = {'process': process, 'time': time, 'level': level, 'iter': 0, 'sweep': sweep}

        controller_hooks.add_to_stats(**stats_key, type='switch_all', value=t_switch)
        controller_hooks.add_to_stats(
            **stats_key,
            type='h_all',
            value=max(abs(item) for item in state_function),
        )

    @staticmethod
    def get_switch(t_interp, state_function, m_guess):
        r"""
        Routine to do the interpolation and root finding stuff. The state function is interpolated by a
        Lagrange interpolant and the root of the interpolant is searched with Newton's method, started at
        ``t_interp[m_guess]``.

        Parameters
        ----------
        t_interp : list
            Collocation nodes in a step.
        state_function : list
            Contains values of state function at these collocation nodes.
        m_guess : int
            Index at which the difference drops below zero; used as index into ``t_interp``.

        Returns
        -------
        t_switch : float
            Time point of found event.
        """

        interpolant = LagrangeApproximation(points=t_interp, fValues=state_function)

        def p(t):
            """Evaluates the interpolant of the state function at time t."""
            return interpolant(t)

        def fprime(t):
            r"""
            Computes the derivative of the scalar interpolant using finite difference. Here,
            the derivative is approximated by the backward difference:

            .. math::
                \frac{dp}{dt} \approx \frac{25 p(t) - 48 p(t - h) + 36 p(t - 2 h) - 16 p(t - 3h) + 3 p(t - 4 h)}{12 h}


            Parameters
            ----------
            t : float
                Time where the derivatives is computed.

            Returns
            -------
            dp : float
                Derivative of interpolation p at time t.
            """

            dt_FD = 1e-10
            dp = (
                25 * p(t) - 48 * p(t - dt_FD) + 36 * p(t - 2 * dt_FD) - 16 * p(t - 3 * dt_FD) + 3 * p(t - 4 * dt_FD)
            ) / (12 * dt_FD)
            return dp

        newton_tol, newton_maxiter = 1e-14, 100
        t_switch = newton(t_interp[m_guess], p, fprime, newton_tol, newton_maxiter)
        return t_switch

    @staticmethod
    def adapt_interpolation_info(t, left_is_node, t_interp, state_function):
        """
        Adapts the x- and y-axis for interpolation. For SDC, it is checked whether the left boundary is a
        collocation node or not. In case it is, the first entry of the state function has to be removed,
        because it would otherwise contain double values on starting time and the first node. Otherwise,
        starting time L.time has to be added to t_interp to also take this value in the interpolation
        into account.

        Note that both lists are modified in place; the returned objects are the ones passed in.

        Parameters
        ----------
        t : float
            Starting time of the step.
        left_is_node : bool
            Indicates whether the left boundary is a collocation node or not.
        t_interp : list
            x-values for interpolation containing collocation nodes.
        state_function : list
            y-values for interpolation containing values of state function.

        Returns
        -------
        t_interp : list
            Adapted x-values for interpolation containing collocation nodes.
        state_function : list
            Adapted y-values for interpolation containing values of state function.
        """

        if left_is_node:
            del state_function[0]
        else:
            t_interp.insert(0, t)

        return t_interp, state_function

366 

367 

def newton(x0, p, fprime, newton_tol, newton_maxiter):
    """
    Newton's method to find the root of interpolant p.

    The iteration stops as soon as the residual ``abs(p(x0))`` drops below ``newton_tol``, when both
    ``p(x0)`` and ``fprime(x0)`` are NaN, when ``fprime(x0)`` is numerically zero, or after
    ``newton_maxiter`` updates. The outcome is reported via ``print``.

    Parameters
    ----------
    x0 : float
        Initial guess. The argument itself is never modified, also not when an array is passed.
    p : callable
        Interpolated function where Newton's method is applied at.
    fprime : callable
        Approximated derivative of p using finite differences.
    newton_tol : float
        Tolerance for termination.
    newton_maxiter : int
        Maximum of iterations the method should execute.

    Returns
    -------
    root : float
        Root of function p, i.e., the last iterate.
    """

    n = 0
    res = None
    converged = False
    while n < newton_maxiter:
        # evaluate function and derivative only once per iteration, fprime is the expensive part
        p_val = p(x0)
        res = abs(p_val)
        if res < newton_tol:
            converged = True
            break

        dp_val = fprime(x0)
        if (np.isnan(p_val) and np.isnan(dp_val)) or np.isclose(dp_val, 0.0):
            break

        # build a new object instead of updating in place, so that a caller's array is left untouched
        x0 = x0 - 1.0 / dp_val * p_val

        n += 1

    if res is None:
        # loop body never ran (newton_maxiter <= 0); the residual is still needed for the message
        res = abs(p(x0))

    if converged:
        msg = f'Newton did converge after {n} iterations, error for root {x0} is {res}'
    elif n == newton_maxiter:
        msg = f'Newton did not converge after {n} iterations, error is {res}'
    else:
        msg = f'Newton stopped after {n} iterations due to NaN or vanishing derivative, error for root {x0} is {res}'
    print(msg)

    root = x0
    return root