Coverage for pySDC/implementations/hooks/log_solution.py: 85%

122 statements  


from pySDC.core.hooks import Hooks
import pickle
import os
import numpy as np
from pySDC.helpers.fieldsIO import FieldsIO
from pySDC.core.errors import DataError

class LogSolution(Hooks):
    """
    Store the solution at the end of each step as "u".
    """

    def post_step(self, step, level_number):
        """
        Record the solution at the end of the step.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        super().post_step(step, level_number)

        L = step.levels[level_number]
        L.sweep.compute_end_point()

        self.add_to_stats(
            process=step.status.slot,
            time=L.time + L.dt,
            level=L.level_index,
            iter=step.status.iter,
            sweep=L.status.sweep,
            type='u',
            value=L.uend,
        )
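# --- Usage sketch (illustrative, not part of the module) ---------------------
# Reading the solutions recorded by ``LogSolution`` back out of the stats
# object after a run, assuming the standard pySDC workflow with the
# ``get_sorted`` helper; ``stats`` is whatever your controller run returned:
#
#     from pySDC.helpers.stats_helper import get_sorted
#
#     # uend, stats = controller.run(u0=u0, t0=t0, Tend=Tend)
#     u = get_sorted(stats, type='u', sorted_by='time')
#     # ``u`` is a list of (time, solution) pairs, one per step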

class LogSolutionAfterIteration(Hooks):
    """
    Store the solution at the end of each iteration as "u".
    """

    def post_iteration(self, step, level_number):
        """
        Record the solution at the end of the iteration.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number

        Returns:
            None
        """
        super().post_iteration(step, level_number)

        L = step.levels[level_number]
        L.sweep.compute_end_point()

        self.add_to_stats(
            process=step.status.slot,
            time=L.time + L.dt,
            level=L.level_index,
            iter=step.status.iter,
            sweep=L.status.sweep,
            type='u',
            value=L.uend,
        )
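# --- Usage sketch (illustrative, not part of the module) ---------------------
# With ``LogSolutionAfterIteration`` the stats hold one entry per iteration, so
# you can watch the solution evolve within a single step. This assumes
# ``get_sorted`` accepts filter keywords matching the entry attributes
# (time, iter, ...):
#
#     from pySDC.helpers.stats_helper import get_sorted
#
#     u_iter = get_sorted(stats, type='u', time=t0 + dt, sorted_by='iter')
#     # list of (iteration, solution) pairs for the step ending at t0 + dt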

class LogToPickleFile(Hooks):
    r"""
    Hook for logging the solution to file after each step using pickle.

    Please configure the hook to your liking by manipulating class attributes.
    You must set a custom path to a directory like so:

    ```
    LogToPickleFile.path = '/my/directory/'
    ```

    Keep in mind that the hook will overwrite files without warning!
    You can set a custom file name via the ``file_name`` class attribute and customize how the index associated
    with individual files is rendered by overriding the ``format_index`` class attribute with a function that
    accepts one index and returns one string.

    If you want to log selectively, supply a custom ``logging_condition`` function that accepts the current level
    and returns whether to log.

    Importantly, you may need to change ``process_solution``. By default, this returns a numpy view of the
    solution, so if you are not using numpy, you need to replace it. Again, this is a function accepting the level.

    After the fact, you can use the classmethod ``get_path`` to get the path to the file for a given index, or the
    ``load`` classmethod to directly load the solution at that index. Just configure the hook the same way you did
    when you recorded the data.

    Finally, be aware that using this hook in MPI-parallel runs may lead to different tasks overwriting files.
    Make sure to give a different ``file_name`` to each task that writes files.
    """

    path = None
    file_name = 'solution'
    counter = 0

    @staticmethod
    def logging_condition(L):
        return True

    @staticmethod
    def process_solution(L):
        return {'t': L.time + L.dt, 'u': L.uend.view(np.ndarray)}

    @staticmethod
    def format_index(index):
        return f'{index:06d}'

    def __init__(self):
        super().__init__()

        if self.path is None:
            raise ValueError('Please set a path for logging as the class attribute `LogToPickleFile.path`!')

        if os.path.isfile(self.path):
            raise ValueError(
                f'{self.path!r} is not a valid path to log to because a file of the same name exists. Please supply a directory!'
            )

        if not os.path.isdir(self.path):
            os.makedirs(self.path, exist_ok=True)

    def log_to_file(self, step, level_number, condition, process_solution=None):
        # only log on the finest level
        if level_number > 0:
            return None

        L = step.levels[level_number]

        if condition:
            path = self.get_path(self.counter)

            if process_solution:
                data = process_solution(L)
            else:
                data = type(self).process_solution(L)

            with open(path, 'wb') as file:
                pickle.dump(data, file)
            self.logger.info(f'Stored file {path!r}')

            type(self).counter += 1

    def post_step(self, step, level_number):
        L = step.levels[level_number]
        self.log_to_file(step, level_number, type(self).logging_condition(L))

    def pre_run(self, step, level_number):
        # log the initial conditions before the run starts
        L = step.levels[level_number]
        L.uend = L.u[0]

        def process_solution(L):
            return {
                **type(self).process_solution(L),
                't': L.time,
            }

        self.log_to_file(step, level_number, True, process_solution=process_solution)

    @classmethod
    def get_path(cls, index):
        return f'{cls.path}/{cls.file_name}_{cls.format_index(index)}.pickle'

    @classmethod
    def load(cls, index):
        path = cls.get_path(index)
        with open(path, 'rb') as file:
            return pickle.load(file)
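# --- Usage sketch (illustrative, not part of the module) ---------------------
# Configuring ``LogToPickleFile`` before a run and loading data back afterwards.
# The path, file name, and condition below are made up for the example; the
# attribute names are the class attributes defined above:
#
#     LogToPickleFile.path = '/my/directory/'
#     LogToPickleFile.file_name = 'my_solution'
#     LogToPickleFile.logging_condition = lambda L: L.time + L.dt >= 0.5  # log late times only
#
#     # ... run pySDC with this hook registered ...
#
#     data = LogToPickleFile.load(0)
#     # ``data`` is the dict written by ``process_solution``, e.g. {'t': ..., 'u': ...}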

class LogToPickleFileAfterXS(LogToPickleFile):
    r'''
    Log to file after a certain amount of time has passed, instead of after every step.
    '''

    time_increment = 0
    t_next_log = 0

    def post_step(self, step, level_number):
        L = step.levels[level_number]

        if self.t_next_log == 0:
            self.t_next_log = self.time_increment

        if L.time + L.dt >= self.t_next_log and not step.status.restart:
            super().post_step(step, level_number)
            self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment

    def pre_run(self, step, level_number):
        # log the initial conditions regardless of the time increment
        L = step.levels[level_number]
        L.uend = L.u[0]

        def process_solution(L):
            return {
                **type(self).process_solution(L),
                't': L.time,
            }

        self.log_to_file(step, level_number, type(self).logging_condition(L), process_solution=process_solution)
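# --- Usage sketch (illustrative, not part of the module) ---------------------
# ``LogToPickleFileAfterXS`` only adds a time-based trigger on top of the
# parent class; the increment below is made up for the example:
#
#     LogToPickleFileAfterXS.path = '/my/directory/'
#     LogToPickleFileAfterXS.time_increment = 0.1  # write roughly every 0.1 time units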

class LogToFile(Hooks):
    """
    Log solutions to a single binary file via the ``FieldsIO`` helper: the initial conditions are written before
    the run and one field is added roughly every ``time_increment`` units of time (after every step if it is 0).
    """

    filename = 'myRun.pySDC'
    time_increment = 0
    allow_overwriting = False

    def __init__(self):
        super().__init__()
        self.outfile = None
        self.t_next_log = 0
        FieldsIO.ALLOW_OVERWRITE = self.allow_overwriting

    def pre_run(self, step, level_number):
        if level_number > 0:
            return None
        L = step.levels[level_number]

        # set up the output file, reusing an existing one when restarting from t > 0
        if os.path.isfile(self.filename) and L.time > 0:
            L.prob.setUpFieldsIO()
            self.outfile = FieldsIO.fromFile(self.filename)
            self.logger.info(
                f'Set up file {self.filename!r} for writing output. This file already contains solutions up to t={self.outfile.times[-1]:.4f}.'
            )
        else:
            self.outfile = L.prob.getOutputFile(self.filename)
            self.logger.info(f'Set up file {self.filename!r} for writing output.')

        # write initial conditions
        if L.time not in self.outfile.times:
            self.outfile.addField(time=L.time, field=L.prob.processSolutionForOutput(L.u[0]))
            self.logger.info(f'Written initial conditions at t={L.time:.4f} to file')

    def post_step(self, step, level_number):
        if level_number > 0:
            return None

        L = step.levels[level_number]

        if self.t_next_log == 0:
            self.t_next_log = L.time + self.time_increment

        if L.time + L.dt >= self.t_next_log and not step.status.restart:
            value_exists = any(abs(me - (L.time + L.dt)) < np.finfo(float).eps * 1000 for me in self.outfile.times)
            if value_exists and not self.allow_overwriting:
                raise DataError(f'Already have recorded data for time {L.time + L.dt} in this file!')
            self.outfile.addField(time=L.time + L.dt, field=L.prob.processSolutionForOutput(L.uend))
            self.logger.info(f'Written solution at t={L.time + L.dt:.4f} to file')
            self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment

    @classmethod
    def load(cls, index):
        # read one (time, field) entry back from the file by index
        data = {}
        file = FieldsIO.fromFile(cls.filename)
        file_entry = file.readField(idx=index)
        data['u'] = file_entry[1]
        data['t'] = file_entry[0]
        return data
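# --- Usage sketch (illustrative, not part of the module) ---------------------
# ``LogToFile`` appends everything to a single FieldsIO file. Set the class
# attributes before the run and read entries back by index afterwards; the
# file name and index are assumptions for the example:
#
#     LogToFile.filename = 'myRun.pySDC'
#     LogToFile.time_increment = 0.1   # 0 means: write after every step
#     LogToFile.allow_overwriting = False
#
#     # ... run pySDC with this hook registered ...
#
#     data = LogToFile.load(0)  # {'t': <time>, 'u': <field>} for the first entry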