Coverage for pySDC/implementations/hooks/log_solution.py: 83%

137 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-04 15:08 +0000

1from pySDC.core.hooks import Hooks 

2import pickle 

3import os 

4import numpy as np 

5from pySDC.helpers.fieldsIO import FieldsIO 

6from pySDC.core.errors import DataError 

7 

8 

class LogSolution(Hooks):
    """
    Hook that records the solution at the end of every step under the stats type "u".
    """

    def post_step(self, step, level_number):
        """
        Add the end point of the given level to the stats once the step is done.

        Args:
            step (pySDC.Step.step): the step that just finished
            level_number (int): index of the level whose end point is recorded

        Returns:
            None
        """
        super().post_step(step, level_number)

        lvl = step.levels[level_number]
        # compute the end point first so that ``uend`` can be recorded below
        lvl.sweep.compute_end_point()

        status = step.status
        t_end = lvl.time + lvl.dt
        self.add_to_stats(
            process=status.slot,
            time=t_end,
            level=lvl.level_index,
            iter=status.iter,
            sweep=lvl.status.sweep,
            type='u',
            value=lvl.uend,
        )

39 

40 

class LogSolutionAfterIteration(Hooks):
    """
    Hook that records the solution after every iteration under the stats type "u".
    """

    def post_iteration(self, step, level_number):
        """
        Add the end point of the given level to the stats once an iteration is done.

        Args:
            step (pySDC.Step.step): the step currently being iterated
            level_number (int): index of the level whose end point is recorded

        Returns:
            None
        """
        super().post_iteration(step, level_number)

        level = step.levels[level_number]
        # compute the end point first so that ``uend`` can be recorded below
        level.sweep.compute_end_point()

        entry = dict(
            process=step.status.slot,
            time=level.time + level.dt,
            level=level.level_index,
            iter=step.status.iter,
            sweep=level.status.sweep,
            type='u',
            value=level.uend,
        )
        self.add_to_stats(**entry)

71 

72 

class LogToPickleFile(Hooks):
    r"""
    Hook for logging the solution to file after the step using pickle.

    Please configure the hook to your liking by manipulating class attributes.
    You must set a custom path to a directory like so:

    ```
    LogToPickleFile.path = '/my/directory/'
    ```

    Keep in mind that the hook will overwrite files without warning!
    You can give a custom file name by setting the ``file_name`` class attribute and give a custom way of rendering the
    index associated with individual files by giving a different function ``format_index`` class attribute. This should
    accept one index and return one string.

    You can also give a custom ``logging_condition`` function, accepting the current level if you want to log selectively.

    Importantly, you may need to change ``process_solution``. By default, this will return a numpy view of the solution.
    Of course, if you are not using numpy, you need to change this. Again, this is a function accepting the level.

    After the fact, you can use the classmethod `get_path` to get the path to the data with a certain index or the `load`
    function to directly load the solution at a given index. Just configure the hook like you did when you recorded the
    data beforehand.

    Finally, be aware that using this hook with MPI parallel runs may lead to different tasks overwriting files. Make
    sure to give a different `file_name` for each task that writes files.
    """

    path = None
    file_name = 'solution'
    counter = 0  # index of the next file to write; stored on the class, so shared by all instances of that class

    # The three functions below are looked up through the class (``type(self)`` / ``cls``) and may be replaced by
    # assigning plain functions as class attributes. ``staticmethod`` additionally makes them safe to access through
    # an instance, where a plain function would otherwise receive ``self`` as its first argument.
    @staticmethod
    def logging_condition(L):
        """Return whether to log on level ``L``. By default, always log."""
        return True

    @staticmethod
    def process_solution(L):
        """Return the data to pickle for level ``L``: end time of the step and a numpy view of ``L.uend``."""
        return {'t': L.time + L.dt, 'u': L.uend.view(np.ndarray)}

    @staticmethod
    def format_index(index):
        """Render a file index as a zero-padded, six-digit string."""
        return f'{index:06d}'

    def __init__(self):
        """
        Validate the configured ``path`` and create the directory if needed.

        Raises:
            ValueError: if ``path`` is not set or points to an existing file
        """
        super().__init__()

        if self.path is None:
            raise ValueError(f'Please set a path for logging as the class attribute `{type(self).__name__}.path`!')

        if os.path.isfile(self.path):
            raise ValueError(
                f'{self.path!r} is not a valid path to log to because a file of the same name exists. Please supply a directory'
            )

        # ``exist_ok`` makes this a no-op when the directory already exists
        os.makedirs(self.path, exist_ok=True)

    def log_to_file(self, step, level_number, condition, process_solution=None):
        """
        Pickle the processed solution of a level to the next file if ``condition`` holds.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number; only level 0 is logged
            condition (bool): whether a file should be written at all
            process_solution (callable, optional): replaces the class's ``process_solution`` for this call

        Returns:
            None
        """
        if level_number > 0 or not condition:
            return None

        L = step.levels[level_number]
        path = self.get_path(self.counter)

        processor = type(self).process_solution if process_solution is None else process_solution
        data = processor(L)

        with open(path, 'wb') as file:
            pickle.dump(data, file)
        self.logger.info(f'Stored file {path!r}')

        type(self).counter += 1

    def post_step(self, step, level_number):
        """
        Log the solution after a step, subject to ``logging_condition``.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        L = step.levels[level_number]
        self.log_to_file(step, level_number, type(self).logging_condition(L))

    def pre_run(self, step, level_number):
        """
        Log the initial conditions ``L.u[0]`` unconditionally, with 't' set to the start time ``L.time``.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        L = step.levels[level_number]
        # ``process_solution`` reads ``uend``, so expose the initial conditions there
        L.uend = L.u[0]

        def _process_initial(L):
            # same data as the configured ``process_solution``, but time-stamped with the start of the step
            return {
                **type(self).process_solution(L),
                't': L.time,
            }

        self.log_to_file(step, level_number, True, process_solution=_process_initial)

    @classmethod
    def get_path(cls, index):
        """Return the path of the file associated with ``index``."""
        return f'{cls.path}/{cls.file_name}_{cls.format_index(index)}.pickle'

    @classmethod
    def load(cls, index):
        """
        Load the data stored for ``index``.

        Only use this on files you trust: unpickling data can execute arbitrary code.
        """
        path = cls.get_path(index)
        with open(path, 'rb') as file:
            return pickle.load(file)

174 

175 

class LogToPickleFileAfterXS(LogToPickleFile):
    r'''
    Log to file after a certain amount of simulated time has passed instead of after every step.

    Set the class attribute ``time_increment`` to the minimal simulated time between two logged solutions.
    '''

    time_increment = 0
    t_next_log = 0  # 0 means "not scheduled yet"; set from the first step in ``post_step``

    def post_step(self, step, level_number):
        """
        Log the solution if at least ``time_increment`` has passed since the last logged time.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number; only level 0 is handled

        Returns:
            None
        """
        # only level 0 is ever written, so do not advance the schedule for other levels
        if level_number > 0:
            return None

        L = step.levels[level_number]
        t_end = L.time + L.dt

        if self.t_next_log == 0:
            # schedule relative to the start time, so that runs starting at t > 0 (e.g. restarts) log at
            # t0 + time_increment, t0 + 2 * time_increment, ... instead of right after the first step
            self.t_next_log = L.time + self.time_increment

        if t_end >= self.t_next_log and not step.status.restart:
            super().post_step(step, level_number)
            self.t_next_log = max(t_end, self.t_next_log) + self.time_increment

    def pre_run(self, step, level_number):
        """
        Log the initial conditions ``L.u[0]`` with 't' set to ``L.time``.

        Unlike ``LogToPickleFile.pre_run``, this respects ``logging_condition``.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        L = step.levels[level_number]
        # ``process_solution`` reads ``uend``, so expose the initial conditions there
        L.uend = L.u[0]

        def _process_initial(L):
            # same data as the configured ``process_solution``, but time-stamped with the start of the step
            return {
                **type(self).process_solution(L),
                't': L.time,
            }

        self.log_to_file(step, level_number, type(self).logging_condition(L), process_solution=_process_initial)

205 

206 

class LogToFile(Hooks):
    """
    Hook for writing the solution to a single file using ``FieldsIO``.

    Configure the hook via class attributes:

    - ``filename``: path of the output file
    - ``time_increment``: minimal simulated time between two stored solutions; 0 stores after every step
    - ``allow_overwriting``: forwarded to ``FieldsIO.ALLOW_OVERWRITE``; additionally, when False, storing a solution
      at a time that is already present in the file raises a ``DataError``
    - ``counter``: number of time points stored in the file; kept up to date by the hook

    If ``filename`` already exists and the run starts at a time > 0, the existing file is reopened and writing
    continues in it. Stored solutions can be read back with the classmethod ``load``.
    """

    filename = 'myRun.pySDC'
    time_increment = 0
    allow_overwriting = False
    counter = 0  # number of stored time points in the file

    def __init__(self):
        super().__init__()
        self.outfile = None
        self.t_next_log = 0  # 0 means the first log time has not been scheduled yet
        # NOTE: this assigns a class attribute of FieldsIO, so it presumably affects every FieldsIO file, not only ours
        FieldsIO.ALLOW_OVERWRITE = self.allow_overwriting

    def _has_time(self, t):
        """
        Check whether the output file already contains a solution at time ``t``.

        Times are compared with an absolute tolerance of 1000 machine epsilons to absorb round-off.
        """
        tol = np.finfo(float).eps * 1000
        return any(abs(me - t) < tol for me in self.outfile.times)

    def pre_run(self, step, level_number):
        """
        Set up the output file and store the initial conditions unless the file already contains them.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number; only level 0 is handled

        Returns:
            None
        """
        if level_number > 0:
            return None
        L = step.levels[level_number]

        # setup outfile
        if os.path.isfile(self.filename) and L.time > 0:
            # restarting: continue writing to the existing file
            L.prob.setUpFieldsIO()
            self.outfile = FieldsIO.fromFile(self.filename)
            # use a local here: assigning ``self.counter`` would create an instance attribute that shadows the class
            # attribute ``counter`` and never gets updated afterwards
            n_stored = len(self.outfile.times)
            if n_stored > 0:
                self.logger.info(
                    f'Set up file {self.filename!r} for writing output. This file already contains {n_stored} solutions up to t={self.outfile.times[-1]:.4f}.'
                )
            else:
                self.logger.info(
                    f'Set up file {self.filename!r} for writing output. This file does not contain any solutions yet.'
                )
        else:
            self.outfile = L.prob.getOutputFile(self.filename)
            self.logger.info(f'Set up file {self.filename!r} for writing output.')

        # write initial conditions (same tolerance as the other time checks to avoid near-duplicate entries)
        if not self._has_time(L.time):
            self.outfile.addField(time=L.time, field=L.prob.processSolutionForOutput(L.u[0]))
            self.logger.info(f'Written initial conditions at t={L.time:.4f} to file')

        type(self).counter = len(self.outfile.times)
        self.logger.info(f'Will write to disk every {self.time_increment:.4e} time units')

    def post_step(self, step, level_number):
        """
        Store the solution at the end of the step if at least ``time_increment`` has passed since the last write.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number; only level 0 is handled

        Returns:
            None

        Raises:
            DataError: if the file already holds data for this time and ``allow_overwriting`` is False
        """
        if level_number > 0:
            return None

        L = step.levels[level_number]
        t_end = L.time + L.dt

        if self.t_next_log == 0:
            self.t_next_log = L.time + self.time_increment

        if t_end >= self.t_next_log and not step.status.restart:
            if self._has_time(t_end) and not self.allow_overwriting:
                raise DataError(f'Already have recorded data for time {t_end} in this file!')
            self.outfile.addField(time=t_end, field=L.prob.processSolutionForOutput(L.uend))
            self.logger.info(f'Written solution at t={t_end:.4f} to file')
            self.t_next_log = max(t_end, self.t_next_log) + self.time_increment
            type(self).counter = len(self.outfile.times)

    def post_run(self, step, level_number):
        """
        Make sure the final solution is stored, even if it did not coincide with a scheduled write.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number; only level 0 is handled

        Returns:
            None
        """
        if level_number > 0:
            return None

        L = step.levels[level_number]
        t_end = L.time + L.dt

        if not self._has_time(t_end):
            self.outfile.addField(time=t_end, field=L.prob.processSolutionForOutput(L.uend))
            self.logger.info(f'Written solution at t={t_end:.4f} to file')
            self.t_next_log = max(t_end, self.t_next_log) + self.time_increment
        type(self).counter = len(self.outfile.times)

    @classmethod
    def load(cls, index):
        """
        Read the solution stored at position ``index`` of ``cls.filename``.

        Args:
            index (int): index of the stored time point

        Returns:
            dict: ``{'u': solution, 't': time}``
        """
        fields_file = FieldsIO.fromFile(cls.filename)
        entry = fields_file.readField(idx=index)
        return {'u': entry[1], 't': entry[0]}