Coverage for pySDC/projects/soft_failure/implicit_sweeper_faults.py: 97%

144 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import struct 

2from datetime import datetime 

3 

4import numpy as np 

5 

6from pySDC.helpers.pysdc_helper import FrozenClass 

7from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

8 

9 

class _fault_stats(FrozenClass):
    """
    Container for the fault statistics gathered by the fault-injecting sweeper.

    All counters start at zero. After construction the instance is frozen via ``FrozenClass._freeze`` (presumably
    this forbids adding new attributes afterwards; see ``FrozenClass``).

    Attributes:
        nfaults_called (int): not incremented by the sweeper in this file
        nfaults_injected_u (int): number of bit flips injected into a solution value u
        nfaults_injected_f (int): number of bit flips injected into a right-hand side value f
        nfaults_detected (int): number of node updates with an injected fault that the detector flagged
        ncorrection_attempts (int): number of times a correction (node recomputation) was started
        nfaults_missed (int): number of node updates with an injected fault that the detector did not flag
        nfalse_positives (int): number of node updates flagged by the detector although no fault was injected
        nfalse_positives_in_correction (int): subset of the false positives that happened during a correction
        nclean_steps (int): number of node updates with neither an injected nor a detected fault
    """

    def __init__(self):
        # every counter starts at zero; the order matches the original attribute definitions
        counter_names = (
            'nfaults_called',
            'nfaults_injected_u',
            'nfaults_injected_f',
            'nfaults_detected',
            'ncorrection_attempts',
            'nfaults_missed',
            'nfalse_positives',
            'nfalse_positives_in_correction',
            'nclean_steps',
        )
        for counter_name in counter_names:
            setattr(self, counter_name, 0)

        # freeze the instance now that all counters exist
        self._freeze()

23 

24 

class implicit_sweeper_faults(generic_implicit):
    """
    Generic implicit sweeper with soft-fault (bit flip) injection, detection and optional correction.

    During a sweep in which ``fault_iteration`` is set, one randomly chosen collocation node receives a single bit
    flip. The flip goes either into the freshly solved ``u`` value or into the freshly evaluated ``f`` value, with
    equal probability. After every node update, the max-norm of the residual of the implicit node solve is compared
    against ``detector_threshold``. If ``allow_fault_correction`` is set, a detected fault causes the node to be
    recomputed once, with no further injection. The outcomes are counted in ``fault_stats``.

    Additional parameters (defaults are filled in by ``__init__``):
        allow_fault_correction (bool): recompute a node once after a detected fault (default False)
        detector_threshold (float): threshold for the max-norm of the solver residual (default 1.0)
        dump_injections_filehandle: writable file-like object receiving one line per injection, or None (default)

    Attributes:
        fault_stats (_fault_stats): counters for injected/detected/missed faults and false positives
        fault_injected (bool): whether a fault was injected at the node currently being updated
        fault_detected (bool): whether the detector fired at the node currently being updated
        in_correction (bool): whether the node currently being updated is a recomputation after a detection
        fault_iteration (bool): whether a fault may be injected in the current sweep; it is never set to True in
            this class, so presumably a caller (hook/controller) toggles it
    """

    def __init__(self, params):
        """
        Initialization routine for the custom sweeper

        Args:
            params (dict): parameters for the sweeper; missing fault-related entries are filled in with their
                defaults (note: this modifies the passed dict in place)
        """

        # defaults for the fault-related parameters (written into params, as the parent reads them from there)
        params.setdefault('allow_fault_correction', False)
        params.setdefault('detector_threshold', 1.0)
        params.setdefault('dump_injections_filehandle', None)

        # call parent's initialization routine
        super().__init__(params)

        self._init_fault_state()

    def _init_fault_state(self):
        """Set fresh fault statistics and clear all fault-related flags (shared by __init__ and reset)."""
        self.fault_stats = _fault_stats()
        self.fault_injected = False
        self.fault_detected = False
        self.in_correction = False
        self.fault_iteration = False

    def reset_fault_stats(self):
        """
        Helper method to reset all fault related stats and flags. Will be called after the run in post-processing.
        """
        self._init_fault_state()

    @staticmethod
    def bitsToFloat(b):
        """
        Static helper method to reinterpret a 64-bit integer as a double precision float

        Args:
            b (int): signed 64-bit integer holding the bit pattern

        Returns:
            float: the IEEE-754 double with the same bit pattern as b
        """
        s = struct.pack('>q', b)
        return struct.unpack('>d', s)[0]

    @staticmethod
    def floatToBits(f):
        """
        Static helper method to reinterpret a double precision float as a 64-bit integer

        Args:
            f (float): number to convert

        Returns:
            int: signed 64-bit integer with the same bit pattern as the IEEE-754 double f
        """
        s = struct.pack('>d', f)
        return struct.unpack('>q', s)[0]

    def do_bitflip(self, a, pos):
        """
        Method to flip a single bit in the double precision representation of a number

        Args:
            a (float): number to modify
            pos (int): bit position between 0 and 63. Positions 0-51 address the mantissa (fraction), 52-62 the
                exponent and 63 the sign bit.

        Returns:
            float: a with the bit at position pos flipped

        Raises:
            ValueError: if pos is not between 0 and 63
        """
        if not 0 <= pos <= 63:
            raise ValueError('bit position has to be between 0 and 63, got %s' % pos)

        # "flip" of the sign bit: negation toggles exactly the sign
        if pos == 63:
            return -a

        # mantissa or exponent bit: XOR with a mask that has a single 1 at pos
        mask = 1 << pos
        return self.bitsToFloat(self.floatToBits(a) ^ mask)

    def inject_fault(self, type=None, target=None):
        """
        Main method to inject a fault: flip one random bit of one random entry of target, in place

        Args:
            type (str): 'u' or 'f', i.e. whether the fault is counted as affecting the solution or the
                right-hand side
            target: indexable container of floats supporting len() and item assignment (here: L.u[m] or L.f[m])

        Raises:
            ValueError: if type is neither 'u' nor 'f'
        """
        # NOTE: `type` shadows the builtin, but it is part of the keyword interface used by callers
        if type not in ('u', 'f'):
            raise ValueError('wrong fault type specified, got %s' % type)

        # random vector entry and random position in its 64-bit representation
        bitflip_entry = np.random.randint(len(target))
        pos = np.random.randint(64)
        old_value = target[bitflip_entry]
        target[bitflip_entry] = self.do_bitflip(old_value, pos)

        if type == 'u':
            self.fault_stats.nfaults_injected_u += 1
        else:
            self.fault_stats.nfaults_injected_f += 1

        self.fault_injected = True

        # optional log line: "<timestamp> --- <type> <entry> <pos> --- <old> <new> <|old - new|>"
        if self.params.dump_injections_filehandle is not None:
            new_value = target[bitflip_entry]
            out = (
                str(datetime.now())
                + ' --- '
                + ' '.join([type, str(bitflip_entry), str(pos)])
                + ' --- '
                + ' '.join([str(old_value), str(new_value), str(np.abs(old_value - new_value))])
                + '\n'
            )
            self.params.dump_injections_filehandle.write(out)

    def detect_fault(self, current_node=None, rhs=None):
        """
        Main method to detect a fault via the residual of the implicit node solve

        Sets ``fault_detected`` if the max-norm of u_m - dt * QI[m, m] * f_m - rhs is NaN or exceeds
        ``detector_threshold``, and updates ``fault_stats`` accordingly.

        Args:
            current_node (int): index m of the node (into L.u / L.f) that was just updated
            rhs: right-hand side the implicit solve for this node was performed with
        """

        # get current level for further use
        L = self.level

        # residual of the implicit solve at this node
        res = L.u[current_node] - L.dt * self.QI[current_node, current_node] * L.f[current_node] - rhs
        res_norm = np.linalg.norm(res, np.inf)
        # NaN compares False against the threshold, hence the explicit check
        self.fault_detected = bool(np.isnan(res_norm) or res_norm > self.params.detector_threshold)

        # update statistics
        stats = self.fault_stats
        if self.fault_injected:
            if self.fault_detected:
                # fault injected and detected
                stats.nfaults_detected += 1
            else:
                # fault injected but not detected
                stats.nfaults_missed += 1
        elif self.fault_detected:
            # no fault injected but detector fired
            stats.nfalse_positives += 1
            if self.in_correction:
                stats.nfalse_positives_in_correction += 1
        else:
            # neither injected nor detected
            stats.nclean_steps += 1

    def correct_fault(self):
        """
        Main method to correct a fault or issue a restart

        Currently this only switches into correction mode: the sweep then recomputes the current node without
        injecting another fault. The attempt is counted and the detection flag is cleared.
        """

        # no injection happens while in_correction is set (see update_nodes)
        self.in_correction = True

        self.fault_stats.ncorrection_attempts += 1
        self.fault_detected = False

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Faults may be injected at one random node (see class docstring), and the detector runs after each node
        update. If a fault is detected and correction is allowed, the same node is recomputed once before the sweep
        moves on.

        Returns:
            None
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # get number of collocation nodes for easier access
        M = self.coll.num_nodes

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        integral = self.integrate()
        for m in range(M):
            # get -QdF(u^k)_m
            for j in range(M + 1):
                integral[m] -= L.dt * self.QI[m + 1, j] * L.f[j]

            # add initial value
            integral[m] += L.u[0]
            # add tau if associated
            if L.tau[m] is not None:
                integral[m] += L.tau[m]

        # node at which a fault may be injected in this sweep (drawn every sweep, keeping the random stream stable)
        fault_node = np.random.randint(M)

        # do the sweep; m is only advanced when the node does not have to be recomputed
        m = 0
        while m < M:
            # see if there will be a fault, and whether it hits u or f
            self.fault_injected = False
            fault_at_u = False
            fault_at_f = False
            if not self.in_correction and m == fault_node and self.fault_iteration:
                if np.random.randint(2) == 0:
                    fault_at_u = True
                else:
                    fault_at_f = True

            # build rhs, consisting of the known values from above and new values from previous nodes (at k+1)
            # this is what needs to be protected separately!
            rhs = P.dtype_u(integral[m])
            for j in range(m + 1):
                rhs += L.dt * self.QI[m + 1, j] * L.f[j]

            t_node = L.time + L.dt * self.coll.nodes[m]

            # implicit solve with prefactor stemming from the diagonal of Qd
            L.u[m + 1] = P.solve_system(rhs, L.dt * self.QI[m + 1, m + 1], L.u[m + 1], t_node)

            # a fault in u is injected before f is evaluated, so it propagates into f
            if fault_at_u:
                self.inject_fault(type='u', target=L.u[m + 1])

            # update function values
            L.f[m + 1] = P.eval_f(L.u[m + 1], t_node)

            if fault_at_f:
                self.inject_fault(type='f', target=L.f[m + 1])

            # see if our detector finds something
            self.detect_fault(current_node=m + 1, rhs=rhs)

            # if we are allowed to try correction, redo this node, otherwise proceed with sweep
            if not self.in_correction and self.fault_detected and self.params.allow_fault_correction:
                self.correct_fault()
            else:
                self.in_correction = False
                m += 1

        # indicate presence of new values at this level
        L.status.updated = True

        return None