Coverage for pySDC/projects/soft_failure/implicit_sweeper_faults.py: 94%
144 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1import struct
2from datetime import datetime
4import numpy as np
6from pySDC.helpers.pysdc_helper import FrozenClass
7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
class _fault_stats(FrozenClass):
    """
    Container for the fault-injection / fault-detection counters of ``implicit_sweeper_faults``.

    Every counter starts at zero. ``_freeze()`` (inherited from ``FrozenClass``) is called last,
    presumably so that no new attributes can be added to an instance afterwards.
    """

    def __init__(self):
        # counter names, in the same order as they are created on the instance
        counter_names = (
            'nfaults_called',  # not updated by the sweeper class in this file; presumably set by callers
            'nfaults_injected_u',  # bit flips injected into u
            'nfaults_injected_f',  # bit flips injected into f
            'nfaults_detected',  # fault injected and detected
            'ncorrection_attempts',  # number of correct_fault() calls
            'nfaults_missed',  # fault injected but not detected
            'nfalse_positives',  # detection without an injected fault
            'nfalse_positives_in_correction',  # false positives while re-doing a node
            'nclean_steps',  # neither injected nor detected
        )
        for counter in counter_names:
            setattr(self, counter, 0)

        self._freeze()
class implicit_sweeper_faults(generic_implicit):
    """
    Generic implicit sweeper with soft-fault (bit flip) injection, detection and optional correction.

    When ``fault_iteration`` is True (it is never set to True inside this class, so the caller has to do
    it), one bit flip is injected during a sweep at a randomly chosen collocation node, either into ``u``
    or into ``f``. After every node update the residual of that node's implicit solve is compared against
    ``detector_threshold``. If a fault is flagged and ``allow_fault_correction`` is set, the node is
    recomputed once (without injection). All outcomes are counted in ``self.fault_stats``.

    Additional parameters (on top of those of ``generic_implicit``):
        allow_fault_correction (bool): recompute a node once after a detection (default: False)
        detector_threshold (float): residual max-norm above which a fault is flagged (default: 1.0)
        dump_injections_filehandle: writable file-like object receiving one line per injected
            fault, or None to disable logging (default: None)
    """

    def __init__(self, params):
        """
        Initialization routine for the custom sweeper

        Args:
            params (dict): parameters for the sweeper; missing fault-related entries are filled in
                with their defaults (the dict is modified in place)
        """
        params.setdefault('allow_fault_correction', False)
        params.setdefault('detector_threshold', 1.0)
        params.setdefault('dump_injections_filehandle', None)

        # call parent's initialization routine (presumably exposes params as self.params, read below)
        super().__init__(params)

        self.reset_fault_stats()

    def reset_fault_stats(self):
        """
        Reset all fault-related statistics and flags. Will be called after the run in post-processing.
        """
        self.fault_stats = _fault_stats()
        self.fault_injected = False
        self.fault_detected = False
        self.in_correction = False
        self.fault_iteration = False

    @staticmethod
    def bitsToFloat(b):
        """
        Reinterpret a signed 64-bit integer as an IEEE-754 double.

        Args:
            b (int): bit representation of a number

        Returns:
            float: the double with bit pattern b
        """
        return struct.unpack('>d', struct.pack('>q', b))[0]

    @staticmethod
    def floatToBits(f):
        """
        Reinterpret an IEEE-754 double as a signed 64-bit integer.

        Args:
            f (float): number to convert

        Returns:
            int: bit representation of f
        """
        return struct.unpack('>q', struct.pack('>d', f))[0]

    def do_bitflip(self, a, pos):
        """
        Flip a single bit in the double-precision representation of a number.

        Args:
            a (float): number to modify
            pos (int): bit position, 0..51 mantissa, 52..62 exponent, 63 sign

        Returns:
            float: a with the bit at pos flipped

        Raises:
            ValueError: if pos is outside 0..63
        """
        if not 0 <= pos <= 63:
            raise ValueError('bit position must be between 0 and 63, got %s' % pos)
        if pos == 63:
            # flipping the sign bit is the same as negation
            return -a
        # xor with a one-hot mask flips exactly the bit at pos
        return self.bitsToFloat(self.floatToBits(a) ^ (1 << pos))

    def inject_fault(self, type=None, target=None):
        """
        Flip one random bit of one random entry of ``target`` (in place) and record the injection.

        Args:
            type (str): 'u' or 'f', i.e. which quantity is affected (name kept for backward compatibility,
                although it shadows the builtin)
            target: indexable, mutable data to be modified (e.g. L.u[m + 1] or L.f[m + 1])

        Raises:
            ValueError: if type is neither 'u' nor 'f'
        """
        if type not in ('u', 'f'):
            # previously this printed a message and called exit(), terminating with status 0
            raise ValueError('wrong fault type specified, got %s' % type)

        # random vector entry first, then random bit position (same RNG order as before)
        bitflip_entry = np.random.randint(len(target))
        pos = np.random.randint(64)
        old_value = target[bitflip_entry]
        target[bitflip_entry] = self.do_bitflip(old_value, pos)

        if type == 'u':
            self.fault_stats.nfaults_injected_u += 1
        else:
            self.fault_stats.nfaults_injected_f += 1

        self.fault_injected = True

        filehandle = self.params.dump_injections_filehandle
        if filehandle is not None:
            new_value = target[bitflip_entry]
            filehandle.write(
                '%s --- %s %s %s --- %s %s %s\n'
                % (datetime.now(), type, bitflip_entry, pos, old_value, new_value, np.abs(old_value - new_value))
            )

    def detect_fault(self, current_node=None, rhs=None):
        """
        Check the residual of the implicit solve at one node and update the detection statistics.

        The residual is u_m - dt * QI[m, m] * f_m - rhs. A fault is flagged if its max-norm is NaN or
        exceeds ``params.detector_threshold``.

        Args:
            current_node (int): node index m (1..M) that was just updated
            rhs: right-hand side that was passed to the implicit solve for this node
        """
        L = self.level

        res = L.u[current_node] - L.dt * self.QI[current_node, current_node] * L.f[current_node] - rhs
        res_norm = np.linalg.norm(res, np.inf)
        # NaN compares False with everything, so it has to be checked explicitly
        self.fault_detected = bool(np.isnan(res_norm) or res_norm > self.params.detector_threshold)

        injected = self.fault_injected
        detected = self.fault_detected
        if injected and detected:
            self.fault_stats.nfaults_detected += 1
        elif detected:
            # no fault injected but fault detected
            self.fault_stats.nfalse_positives += 1
            if self.in_correction:
                self.fault_stats.nfalse_positives_in_correction += 1
        elif injected:
            self.fault_stats.nfaults_missed += 1
        else:
            self.fault_stats.nclean_steps += 1

    def correct_fault(self):
        """
        Enter correction mode: the current node is recomputed once, without fault injection.
        """
        self.in_correction = True
        self.fault_stats.ncorrection_attempts += 1
        self.fault_detected = False

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        M = self.coll.num_nodes

        # gather all terms which are known already: u0 + QF(u^k) - QdF(u^k) + tau
        integral = self.integrate()
        for m in range(M):
            for j in range(M + 1):
                integral[m] -= L.dt * self.QI[m + 1, j] * L.f[j]
            integral[m] += L.u[0]
            if L.tau[m] is not None:
                integral[m] += L.tau[m]

        # candidate node for an injection; drawn every sweep so the random stream does not depend on fault_iteration
        fault_node = np.random.randint(M)

        m = 0
        while m < M:
            self.fault_injected = False
            fault_type = None
            if not self.in_correction and m == fault_node and self.fault_iteration:
                fault_type = 'u' if np.random.randint(2) == 0 else 'f'

            # rhs: known values from above plus new values from previous nodes (at k+1)
            rhs = P.dtype_u(integral[m])
            for j in range(m + 1):
                rhs += L.dt * self.QI[m + 1, j] * L.f[j]

            t_node = L.time + L.dt * self.coll.nodes[m]

            # implicit solve with prefactor stemming from the diagonal of Qd
            L.u[m + 1] = P.solve_system(rhs, L.dt * self.QI[m + 1, m + 1], L.u[m + 1], t_node)
            if fault_type == 'u':
                # corrupt u before f is evaluated from it
                self.inject_fault(type='u', target=L.u[m + 1])

            L.f[m + 1] = P.eval_f(L.u[m + 1], t_node)
            if fault_type == 'f':
                self.inject_fault(type='f', target=L.f[m + 1])

            self.detect_fault(current_node=m + 1, rhs=rhs)

            # retry this node once if allowed, otherwise (or after the retry) move on
            if not self.in_correction and self.fault_detected and self.params.allow_fault_correction:
                self.correct_fault()
            else:
                self.in_correction = False
                m += 1

        # indicate presence of new values at this level
        L.status.updated = True

        return None