Coverage for pySDC/implementations/hooks/log_solution.py: 83%
137 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-04 15:08 +0000
1from pySDC.core.hooks import Hooks
2import pickle
3import os
4import numpy as np
5from pySDC.helpers.fieldsIO import FieldsIO
6from pySDC.core.errors import DataError
class LogSolution(Hooks):
    """
    Hook that records the solution at the end of every step under the stats type "u".
    """

    def post_step(self, step, level_number):
        """
        Store the end-point solution of the given level once the step has finished.

        Args:
            step (pySDC.Step.step): the step that just finished
            level_number (int): index of the level whose solution is recorded

        Returns:
            None
        """
        super().post_step(step, level_number)

        lvl = step.levels[level_number]
        # make sure lvl.uend reflects the current state before recording it
        lvl.sweep.compute_end_point()

        stats_entry = dict(
            process=step.status.slot,
            time=lvl.time + lvl.dt,
            level=lvl.level_index,
            iter=step.status.iter,
            sweep=lvl.status.sweep,
            type='u',
            value=lvl.uend,
        )
        self.add_to_stats(**stats_entry)
class LogSolutionAfterIteration(Hooks):
    """
    Hook that records the solution after every iteration under the stats type "u".
    """

    def post_iteration(self, step, level_number):
        """
        Store the current end-point solution of the given level after an iteration.

        Args:
            step (pySDC.Step.step): the step that is being iterated
            level_number (int): index of the level whose solution is recorded

        Returns:
            None
        """
        super().post_iteration(step, level_number)

        level = step.levels[level_number]
        # refresh level.uend from the current iterate before recording it
        level.sweep.compute_end_point()

        record = {
            'process': step.status.slot,
            'time': level.time + level.dt,
            'level': level.level_index,
            'iter': step.status.iter,
            'sweep': level.status.sweep,
            'type': 'u',
            'value': level.uend,
        }
        self.add_to_stats(**record)
class LogToPickleFile(Hooks):
    r"""
    Hook for logging the solution to file after the step using pickle.

    Please configure the hook to your liking by manipulating class attributes.
    You must set a custom path to a directory like so:

    ```
    LogToPickleFile.path = '/my/directory/'
    ```

    Keep in mind that the hook will overwrite files without warning!
    You can give a custom file name by setting the ``file_name`` class attribute and give a custom way of rendering the
    index associated with individual files by giving a different function ``format_index`` class attribute. This should
    accept one index and return one string.

    You can also give a custom ``logging_condition`` function, accepting the current level if you want to log selectively.

    Importantly, you may need to change ``process_solution``. By default, this will return a numpy view of the solution.
    Of course, if you are not using numpy, you need to change this. Again, this is a function accepting the level.

    After the fact, you can use the classmethod `get_path` to get the path to a certain data or the `load` function to
    directly load the solution at a given index. Just configure the hook like you did when you recorded the data
    beforehand.

    Finally, be aware that using this hook with MPI parallel runs may lead to different tasks overwriting files. Make
    sure to give a different `file_name` for each task that writes files.
    """

    path = None  # directory to write the pickle files to; must be set before instantiation
    file_name = 'solution'  # prefix of the individual file names
    counter = 0  # index of the next file to write; incremented on the class after every write

    @staticmethod
    def logging_condition(L):
        """Decide, given the level, whether to log after the step. Logs every step by default."""
        return True

    @staticmethod
    def process_solution(L):
        """Turn the level into the data that gets pickled: end time and a numpy view of ``L.uend``."""
        return {'t': L.time + L.dt, 'u': L.uend.view(np.ndarray)}

    @staticmethod
    def format_index(index):
        """Render a file index as a zero-padded, six digit string."""
        return f'{index:06d}'

    def __init__(self):
        """
        Check the configured ``path`` and create the directory if needed.

        Raises:
            ValueError: if no path is configured or if ``path`` points to an existing file
        """
        super().__init__()

        if self.path is None:
            raise ValueError(f'Please set a path for logging as the class attribute `{type(self).__name__}.path`!')

        if os.path.isfile(self.path):
            raise ValueError(
                f'{self.path!r} is not a valid path to log to because a file of the same name exists. Please supply a directory'
            )

        # exist_ok makes this a no-op for directories that already exist
        os.makedirs(self.path, exist_ok=True)

    def log_to_file(self, step, level_number, condition, process_solution=None):
        """
        Pickle the processed solution of the finest level to the next file and advance the counter.

        Args:
            step (pySDC.Step.step): the current step
            level_number (int): the current level number; only level 0 is logged
            condition (bool): whether to write anything at all
            process_solution (callable): optional replacement for the class' ``process_solution``

        Returns:
            None
        """
        if level_number > 0 or not condition:
            return None

        L = step.levels[level_number]
        path = self.get_path(self.counter)

        processor = type(self).process_solution if process_solution is None else process_solution
        data = processor(L)

        with open(path, 'wb') as file:
            pickle.dump(data, file)
        self.logger.info(f'Stored file {path!r}')

        # the counter lives on the class so that `get_path`/`load` can be used after the run
        type(self).counter += 1

    def post_step(self, step, level_number):
        """Log the solution after the step if ``logging_condition`` allows it."""
        L = step.levels[level_number]
        self.log_to_file(step, level_number, type(self).logging_condition(L))

    def pre_run(self, step, level_number):
        """Unconditionally log the initial conditions, stamped with the start time of the run."""
        L = step.levels[level_number]
        # process_solution reads L.uend, so expose the initial value L.u[0] there before logging
        L.uend = L.u[0]

        def process_initial_conditions(L):
            # reuse the configured processing, but use the start time instead of L.time + L.dt
            return {
                **type(self).process_solution(L),
                't': L.time,
            }

        self.log_to_file(step, level_number, True, process_solution=process_initial_conditions)

    @classmethod
    def get_path(cls, index):
        """Return the path of the file with the given index."""
        return f'{cls.path}/{cls.file_name}_{cls.format_index(index)}.pickle'

    @classmethod
    def load(cls, index):
        """
        Load the data stored at the given index.

        NOTE: this uses `pickle.load`, which can execute arbitrary code. Only load files you wrote yourself / trust.
        """
        path = cls.get_path(index)
        with open(path, 'rb') as file:
            return pickle.load(file)
class LogToPickleFileAfterXS(LogToPickleFile):
    r'''
    Log to file after a certain amount of simulated time has passed instead of after every step.

    Set the class attribute ``time_increment`` to the minimum amount of simulated time between two stored solutions.
    The initial conditions are stored in ``pre_run`` if ``logging_condition`` allows it.
    '''

    time_increment = 0  # minimum simulated time between two stored solutions
    t_next_log = 0  # time from which on the next solution is stored; 0 means "not scheduled yet"

    def post_step(self, step, level_number):
        """
        Store the solution once the end of the step has reached the next scheduled output time.

        Restarted steps are never stored. Only level 0 is considered.
        """
        if level_number > 0:
            return None

        L = step.levels[level_number]
        t_end = L.time + L.dt

        # first call: schedule the first output one increment after the start of the run, like `LogToFile` does,
        # so runs that do not start at t=0 do not log right away
        if self.t_next_log == 0:
            self.t_next_log = L.time + self.time_increment

        if t_end >= self.t_next_log and not step.status.restart:
            super().post_step(step, level_number)
            self.t_next_log = max(t_end, self.t_next_log) + self.time_increment

    def pre_run(self, step, level_number):
        """Log the initial conditions, stamped with the start time, if ``logging_condition`` allows it."""
        L = step.levels[level_number]
        # process_solution reads L.uend, so expose the initial value L.u[0] there before logging
        L.uend = L.u[0]

        def process_initial_conditions(L):
            # reuse the configured processing, but use the start time instead of L.time + L.dt
            return {
                **type(self).process_solution(L),
                't': L.time,
            }

        self.log_to_file(
            step, level_number, type(self).logging_condition(L), process_solution=process_initial_conditions
        )
class LogToFile(Hooks):
    """
    Write the solution to a single ``FieldsIO`` file during the run.

    Configure the hook through its class attributes:

    - ``filename``: path of the output file
    - ``time_increment``: minimum simulated time between two stored solutions; ``0`` stores every step
    - ``allow_overwriting``: if False, storing a solution for a time already in the file raises a ``DataError``
    - ``time_tolerance``: absolute tolerance used to decide whether a time is already stored in the file

    If the file exists and the run does not start at t=0, the hook appends to the existing file.
    The class attribute ``counter`` holds the number of time points stored in the file.
    """

    filename = 'myRun.pySDC'
    time_increment = 0
    allow_overwriting = False
    counter = 0  # number of stored time points in the file
    time_tolerance = np.finfo(float).eps * 1000

    def __init__(self):
        super().__init__()
        self.outfile = None
        self.t_next_log = 0
        # NOTE: this sets a class attribute of FieldsIO, so it affects every FieldsIO instance in the process
        FieldsIO.ALLOW_OVERWRITE = self.allow_overwriting

    def _time_is_stored(self, time):
        """Return True if ``time`` matches a time already stored in the output file within ``time_tolerance``."""
        return any(abs(stored - time) < self.time_tolerance for stored in self.outfile.times)

    def pre_run(self, step, level_number):
        """Open (or resume) the output file and write the initial conditions if they are not stored yet."""
        if level_number > 0:
            return None
        L = step.levels[level_number]

        # setup outfile: resume an existing file when not starting at t=0, otherwise let the problem create one
        if os.path.isfile(self.filename) and L.time > 0:
            L.prob.setUpFieldsIO()
            self.outfile = FieldsIO.fromFile(self.filename)
            stored_times = self.outfile.times
            n_stored = len(stored_times)
            if n_stored > 0:
                self.logger.info(
                    f'Set up file {self.filename!r} for writing output. This file already contains {n_stored} solutions up to t={stored_times[-1]:.4f}.'
                )
            else:
                self.logger.info(f'Set up file {self.filename!r} for writing output.')
        else:
            self.outfile = L.prob.getOutputFile(self.filename)
            self.logger.info(f'Set up file {self.filename!r} for writing output.')

        # write initial conditions
        if not self._time_is_stored(L.time):
            self.outfile.addField(time=L.time, field=L.prob.processSolutionForOutput(L.u[0]))
            self.logger.info(f'Written initial conditions at t={L.time:.4f} to file')

        # keep the counter on the class only, so instances never see a stale copy
        type(self).counter = len(self.outfile.times)
        self.logger.info(f'Will write to disk every {self.time_increment:.4e} time units')

    def post_step(self, step, level_number):
        """
        Write the solution once the end of the step reaches the next scheduled output time.

        Restarted steps are never written. Only level 0 is considered.

        Raises:
            DataError: if the time is already stored and ``allow_overwriting`` is False
        """
        if level_number > 0:
            return None

        L = step.levels[level_number]
        t_end = L.time + L.dt

        if self.t_next_log == 0:
            self.t_next_log = L.time + self.time_increment

        if t_end >= self.t_next_log and not step.status.restart:
            if not self.allow_overwriting and self._time_is_stored(t_end):
                raise DataError(f'Already have recorded data for time {t_end} in this file!')
            self.outfile.addField(time=t_end, field=L.prob.processSolutionForOutput(L.uend))
            self.logger.info(f'Written solution at t={t_end:.4f} to file')
            self.t_next_log = max(t_end, self.t_next_log) + self.time_increment
            type(self).counter = len(self.outfile.times)

    def post_run(self, step, level_number):
        """Make sure the final solution ends up in the file, regardless of ``time_increment``."""
        if level_number > 0:
            return None

        L = step.levels[level_number]
        t_end = L.time + L.dt

        if not self._time_is_stored(t_end):
            self.outfile.addField(time=t_end, field=L.prob.processSolutionForOutput(L.uend))
            self.logger.info(f'Written solution at t={t_end:.4f} to file')
            self.t_next_log = max(t_end, self.t_next_log) + self.time_increment
        type(self).counter = len(self.outfile.times)

    @classmethod
    def load(cls, index):
        """
        Read the entry with the given index from ``filename``.

        Returns:
            dict: the stored solution under 'u' and its time under 't'
        """
        infile = FieldsIO.fromFile(cls.filename)
        entry = infile.readField(idx=index)
        data = {}
        data['u'] = entry[1]
        data['t'] = entry[0]
        return data