Coverage for pySDC/projects/PinTSimE/switch_estimator.py: 99%
105 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 16:55 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 16:55 +0000
1import numpy as np
2import scipy as sp
4from pySDC.core.errors import ParameterError
5from pySDC.core.collocation import CollBase
6from pySDC.core.convergence_controller import ConvergenceController, Status
7from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
8from qmat.lagrange import LagrangeApproximation
class SwitchEstimator(ConvergenceController):
    """
    Class to predict the time point of an event and to set a new step size matching it.

    An event is reported by the problem class through ``get_switching_info``. Its location is estimated by
    interpolating the state function over the current step (Lagrange interpolation) and computing the root of the
    interpolant with Newton's method. So far this is a nonMPI version, because an MPI version is not yet developed.
    """

    def setup(self, controller, params, description):
        r"""
        Function sets default variables to handle with the event at the beginning. The default params are:

            - control_order : controls the order of the SE's call of convergence controllers.
            - nodes : collocation nodes used for interpolation; they are scaled by the step size ``L.dt`` and
              shifted by ``L.time`` to obtain the interpolation time points.
            - tol_zero : inner tolerance for SE; state function has to satisfy it to terminate.
            - t_interp : interpolation axis with time points.
            - state_function : List of values from state function.

        Note that ``tol`` and ``alpha`` get no defaults here, but are read in ``get_new_step_size``; they therefore
        have to be supplied via ``params``.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        params : dict
            The parameters passed for this specific convergence controller.
        description : dict
            The description object used to instantiate the controller.

        Returns
        -------
        convergence_controller_params : dict
            The updated params dictionary (user-supplied ``params`` override the defaults).
        """

        # for RK4 sweeper, sweep.coll.nodes now consists of values of ButcherTableau
        # for this reason, collocation nodes will be generated here
        coll = CollBase(
            num_nodes=description['sweeper_params']['num_nodes'],
            quad_type=description['sweeper_params']['quad_type'],
        )

        defaults = {
            'control_order': 0,
            'nodes': coll.nodes,
            'tol_zero': 2.5e-12,
            't_interp': [],
            'state_function': [],
        }
        return {**defaults, **params}

    def setup_status_variables(self, controller, **kwargs):
        """
        Adds switching specific variables to status variables.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        """

        self.status = Status(['is_zero', 'switch_detected', 't_switch'])

    def reset_status_variables(self, controller, **kwargs):
        """
        Resets status variables by creating a fresh set of them.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        """

        self.setup_status_variables(controller, **kwargs)

    def get_new_step_size(self, controller, S, **kwargs):
        """
        Determine a new step size when an event is found such that the event occurs at the time step.

        Only acts once the step has converged. If the problem reports a switch, the state function is
        interpolated over the step; an event that is already resolved (state function within ``tol_zero`` at
        one end) is logged directly, otherwise its time is estimated and ``L.status.dt_new`` is set so that the
        step ends at the (``alpha``-scaled) event time.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        S : pySDC.Step
            The current step.
        """

        L = S.levels[0]

        if CheckConvergence.check_convergence(S):
            self.status.switch_detected, m_guess, self.params.state_function = L.prob.get_switching_info(L.u, L.time)

            if self.status.switch_detected:
                self.params.t_interp = [L.time + L.dt * node for node in self.params.nodes]
                self.params.t_interp, self.params.state_function = self.adapt_interpolation_info(
                    L.time, L.sweep.coll.left_is_node, self.params.t_interp, self.params.state_function
                )

                # when the state function is already close to zero the event is already resolved well
                left_is_zero = abs(self.params.state_function[0]) <= self.params.tol_zero
                right_is_zero = abs(self.params.state_function[-1]) <= self.params.tol_zero
                if left_is_zero or right_is_zero:
                    if left_is_zero:
                        t_switch = self.params.t_interp[0]
                        boundary = 'left'
                    else:
                        boundary = 'right'
                        t_switch = self.params.t_interp[-1]

                    msg = f"The value of state function is close to zero, thus event time is already close enough to the {boundary} end point!"
                    self.log(msg, S)
                    self.log_event_time(
                        controller.hooks[0], S.status.slot, L.time, L.level_index, L.status.sweep, t_switch
                    )

                    L.prob.count_switches()
                    # marks the event as resolved, which disables the estimation below
                    self.status.is_zero = True

                # intermediate value theorem states that a root is contained in current step
                if self.params.state_function[0] * self.params.state_function[-1] < 0 and self.status.is_zero is None:
                    self.status.t_switch = self.get_switch(self.params.t_interp, self.params.state_function, m_guess)

                    self.logging_during_estimation(
                        controller.hooks[0],
                        S.status.slot,
                        L.time,
                        L.level_index,
                        L.status.sweep,
                        self.status.t_switch,
                        self.params.state_function,
                    )

                    if L.time < self.status.t_switch < L.time + L.dt:
                        dt_switch = (self.status.t_switch - L.time) * self.params.alpha

                        if (
                            abs(self.status.t_switch - L.time) <= self.params.tol
                            or abs((L.time + L.dt) - self.status.t_switch) <= self.params.tol
                        ):
                            self.log(f"Switch located at time {self.status.t_switch:.15f}", S)
                            L.prob.t_switch = self.status.t_switch
                            self.log_event_time(
                                controller.hooks[0],
                                S.status.slot,
                                L.time,
                                L.level_index,
                                L.status.sweep,
                                self.status.t_switch,
                            )

                            L.prob.count_switches()

                        else:
                            self.log(f"Located Switch at time {self.status.t_switch:.15f} is outside the range", S)

                        # when an event is found, step size matching with this event should be preferred;
                        # switch_detected is necessarily still set here (it guards this whole branch), so the
                        # event-matching step size always takes precedence over any previously planned one
                        L.status.dt_new = dt_switch

                    else:
                        # event occurs on L.time or L.time + L.dt; no restart necessary
                        # NOTE(review): this branch is also taken if Newton returns a time outside the step
                        # entirely; it is then reported as a 'right boundary' event - confirm this is intended
                        boundary = 'left boundary' if self.status.t_switch == L.time else 'right boundary'
                        self.log(f"Estimated switch {self.status.t_switch:.15f} occurs at {boundary}", S)
                        self.log_event_time(
                            controller.hooks[0],
                            S.status.slot,
                            L.time,
                            L.level_index,
                            L.status.sweep,
                            self.status.t_switch,
                        )
                        L.prob.count_switches()
                        self.status.switch_detected = False

                else:  # intermediate value theorem is not satisfied
                    self.status.switch_detected = False

    def determine_restart(self, controller, S, **kwargs):
        """
        Check if the step needs to be restarted due to a predicted switch.

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        S : pySDC.Step
            The current step.
        """

        if self.status.switch_detected:
            S.status.restart = True
            S.status.force_done = True

        super().determine_restart(controller, S, **kwargs)

    def post_step_processing(self, controller, S, **kwargs):
        """
        After a step is done, some variables will be prepared for predicting a possibly new switch.
        If no event time was estimated and no new step size is set (e.g. by Adaptivity), the next time step
        will be set as the default one from the front end (``L.params.dt_initial``).

        Parameters
        ----------
        controller : pySDC.Controller
            The controller doing all the stuff in a computation.
        S : pySDC.Step
            The current step.
        """

        L = S.levels[0]

        if self.status.t_switch is None:
            L.status.dt_new = L.status.dt_new if L.status.dt_new is not None else L.params.dt_initial

        super().post_step_processing(controller, S, **kwargs)

    @staticmethod
    def log_event_time(controller_hooks, process, time, level, sweep, t_switch):
        """
        Logs the event time of an event satisfying an appropriate criterion, e.g., event is already resolved well,
        event time satisfies tolerance.

        Parameters
        ----------
        controller_hooks : pySDC.Controller.hooks
            Controller with access to the hooks.
        process : int
            Process for logging.
        time : float
            Time at which the event time is logged (denotes the current step).
        level : int
            Level at which event is found.
        sweep : int
            Denotes the number of sweep.
        t_switch : float
            Event time found by switch estimation.
        """

        controller_hooks.add_to_stats(
            process=process,
            time=time,
            level=level,
            iter=0,
            sweep=sweep,
            type='switch',
            value=t_switch,
        )

    @staticmethod
    def logging_during_estimation(controller_hooks, process, time, level, sweep, t_switch, state_function):
        """
        Logs every estimated event time (stats type ``'switch_all'``) together with the largest absolute value
        of the state function used for the estimation (stats type ``'h_all'``).

        Parameters
        ----------
        controller_hooks : pySDC.Controller.hooks
            Controller with access to the hooks.
        process : int
            Process for logging.
        time : float
            Time at which the values are logged (denotes the current step).
        level : int
            Level at which the event is estimated.
        sweep : int
            Denotes the number of sweep.
        t_switch : float
            Estimated event time.
        state_function : list
            Values of the state function used for the interpolation.
        """

        controller_hooks.add_to_stats(
            process=process,
            time=time,
            level=level,
            iter=0,
            sweep=sweep,
            type='switch_all',
            value=t_switch,
        )
        controller_hooks.add_to_stats(
            process=process,
            time=time,
            level=level,
            iter=0,
            sweep=sweep,
            type='h_all',
            value=max(abs(item) for item in state_function),
        )

    @staticmethod
    def get_switch(t_interp, state_function, m_guess):
        r"""
        Routine to do the interpolation and root finding stuff.

        Parameters
        ----------
        t_interp : list
            Interpolation time points in a step.
        state_function : list
            Contains values of state function at these time points.
        m_guess : int
            Index into ``t_interp`` of the time point used as initial guess for Newton's method.

        Returns
        -------
        t_switch : float
            Time point of found event.
        """

        interpolant = LagrangeApproximation(points=t_interp, fValues=state_function)

        def p(t):
            return interpolant(t)

        def fprime(t):
            r"""
            Computes the derivative of the scalar interpolant using finite difference. Here,
            the derivative is approximated by the backward difference:

            .. math::
                \frac{dp}{dt} \approx \frac{25 p(t) - 48 p(t - h) + 36 p(t - 2 h) - 16 p(t - 3h) + 3 p(t - 4 h)}{12 h}

            Parameters
            ----------
            t : float
                Time where the derivative is computed.

            Returns
            -------
            dp : float
                Derivative of interpolation p at time t.
            """

            dt_FD = 1e-10
            dp = (
                25 * p(t) - 48 * p(t - dt_FD) + 36 * p(t - 2 * dt_FD) - 16 * p(t - 3 * dt_FD) + 3 * p(t - 4 * dt_FD)
            ) / (12 * dt_FD)
            return dp

        newton_tol, newton_maxiter = 1e-14, 100
        t_switch = newton(t_interp[m_guess], p, fprime, newton_tol, newton_maxiter)
        return t_switch

    @staticmethod
    def adapt_interpolation_info(t, left_is_node, t_interp, state_function):
        """
        Adapts the x- and y-axis for interpolation. For SDC, it is checked whether the left boundary is a
        collocation node or not. In case it is, the first entry of the state function has to be removed,
        because it would otherwise contain double values on starting time and the first node. Otherwise,
        starting time L.time has to be added to t_interp to also take this value in the interpolation
        into account.

        The inputs are not modified; adapted sequences are returned instead.

        Parameters
        ----------
        t : float
            Starting time of the step.
        left_is_node : bool
            Indicates whether the left boundary is a collocation node or not.
        t_interp : list
            x-values for interpolation containing collocation nodes.
        state_function : list
            y-values for interpolation containing values of state function.

        Returns
        -------
        t_interp : list
            Adapted x-values for interpolation containing collocation nodes.
        state_function : list
            Adapted y-values for interpolation containing values of state function.
        """

        if not left_is_node:
            t_interp = [t, *t_interp]
        else:
            state_function = state_function[1:]

        return t_interp, state_function
def newton(x0, p, fprime, newton_tol, newton_maxiter):
    """
    Newton's method to find the root of interpolant p.

    The iteration stops when the residual ``|p(x)|`` drops below ``newton_tol``, when the derivative is
    (numerically) zero or both ``p`` and its derivative are NaN, or after ``newton_maxiter`` iterations.
    A message about the outcome is printed.

    Parameters
    ----------
    x0 : float
        Initial guess.
    p : callable
        Interpolated function where Newton's method is applied at.
    fprime : callable
        Approximated derivative of p using finite differences.
    newton_tol : float
        Tolerance for termination.
    newton_maxiter : int
        Maximum of iterations the method should execute. With a value <= 0 the initial guess is returned.

    Returns
    -------
    root : float
        Root of function p (the last iterate if the method stopped without converging).
    """

    n = 0
    # evaluate p once per iterate; ``res`` is also defined when no iteration is performed
    p_x = p(x0)
    res = abs(p_x)
    while n < newton_maxiter:
        if res < newton_tol:
            break

        dp_x = fprime(x0)
        # np.isclose uses its default absolute tolerance here, i.e. very flat derivatives also stop Newton
        if (np.isnan(p_x) and np.isnan(dp_x)) or np.isclose(dp_x, 0.0):
            break

        # rebind instead of updating in place, so a mutable (array) initial guess is never modified
        x0 = x0 - 1.0 / dp_x * p_x
        p_x = p(x0)
        res = abs(p_x)

        n += 1

    # success is decided by the residual, not by why the loop ended (a vanishing derivative is no convergence)
    if res < newton_tol:
        msg = f'Newton did converge after {n} iterations, error for root {x0} is {res}'
    else:
        msg = f'Newton did not converge after {n} iterations, error is {res}'
    print(msg)

    root = x0
    return root