Coverage for pySDC/core/controller.py: 98%
198 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-04 15:08 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-04 15:08 +0000
1import logging
2import os
3import sys
4import numpy as np
6from pySDC.core.base_transfer import BaseTransfer
7from pySDC.helpers.pysdc_helper import FrozenClass
8from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
9from pySDC.implementations.hooks.default_hook import DefaultHooks
10from pySDC.implementations.hooks.log_timings import CPUTimings
13# short helper class to add params as attributes
class _Pars(FrozenClass):
    """
    Small container that exposes the controller parameters as attributes.

    The built-in defaults are assigned first, then every entry of the user
    supplied dictionary is assigned on top (overriding a default or adding a
    new attribute). Finally ``_freeze()`` is called, which is inherited from
    ``FrozenClass`` (presumably it forbids adding further attributes
    afterwards -- confirm against ``pySDC.helpers.pysdc_helper``).

    Args:
        params (dict): user-defined controller parameters
    """

    def __init__(self, params):
        # default values, listed in the order in which they become attributes
        defaults = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            # the log file name contains the id of the current process
            'fname': 'run_pid' + str(os.getpid()) + '.log',
            'use_iteration_estimator': False,
        }
        for name, value in defaults.items():
            setattr(self, name, value)

        # user-supplied values take precedence over the defaults
        for name, value in params.items():
            setattr(self, name, value)

        self._freeze()
class Controller(object):
    """
    Base abstract controller class

    Takes care of the functionality shared by all controllers: instantiating
    the hooks, setting up the logging facility, managing the convergence
    controllers and printing an overview of the setup. Derived classes have to
    implement ``run``.
    """

    def __init__(self, controller_params, description, useMPI=None):
        """
        Initialization routine for the base controller

        Note that ``controller_params['hook_class']`` is overwritten with the
        full list of hook classes (defaults first, then the user-defined ones).

        Args:
            controller_params (dict): parameter set for the controller and the steps
            description (dict): description of the problem, also passed on to the convergence controllers
            useMPI (bool): stored on the controller and forwarded to the convergence controllers
        """
        self.useMPI = useMPI
        self.description = description

        # check if we have a hook on this list. If not, use default class.
        self.__hooks = []
        hook_classes = [DefaultHooks, CPUTimings]
        user_hooks = controller_params.get('hook_class', [])
        # accept a single hook class as well as a list (or tuple) of hook classes
        hook_classes += list(user_hooks) if isinstance(user_hooks, (list, tuple)) else [user_hooks]
        for hook in hook_classes:
            self.add_hook(hook)
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(level=None, log_to_file=None, fname=None):
        """
        Helper function to set main parameters for the logging facility

        All handlers currently attached to the root logger are removed and
        replaced by a handler writing to stdout and, if requested, a handler
        writing to file. File handlers created by an earlier call of this
        function are closed when they are removed, so that repeatedly
        instantiating controllers does not leave log files open.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file, only used if log_to_file is set
        """

        assert type(level) is int

        # specify formats and handlers
        if log_to_file:
            file_formatter = logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
            )
            # append to an existing log file, otherwise start a new one
            mode = 'a' if os.path.isfile(fname) else 'w'
            file_handler = logging.FileHandler(fname, mode=mode)
            file_handler.setFormatter(file_formatter)
            # mark the handler as ours, such that a later call may safely close it
            file_handler._pysdc_controller_handler = True
        else:
            file_handler = None

        std_formatter = logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s')

        if level <= logging.DEBUG:
            import warnings

            warnings.warn('Running with debug output will degrade performance as all output is immediately flushed.')

            class StreamFlushingHandler(logging.StreamHandler):
                """
                This will immediately flush any messages to the output.
                """

                def emit(self, record):
                    super().emit(record)
                    self.flush()

            std_handler = StreamFlushingHandler(sys.stdout)
        else:
            std_handler = logging.StreamHandler(sys.stdout)

        std_handler.setFormatter(std_formatter)

        # instantiate logger
        logger = logging.getLogger('')

        # remove handlers from previous calls to controller
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            # only close what we opened ourselves, foreign handlers are merely detached
            if getattr(handler, '_pysdc_controller_handler', False):
                handler.close()

        logger.setLevel(level)
        logger.addHandler(std_handler)
        if file_handler is not None:
            logger.addHandler(file_handler)

    def add_hook(self, hook):
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            self.__hooks += [hook()]

    def welcome_message(self):
        """
        Log the welcome banner with level "info" via the controller's logger
        """
        banner = (
            "Welcome to the one and only, really very astonishing and 87.3% bug free",
            r" _____ _____ _____ ",
            r" / ____| __ \ / ____|",
            r" _ __ _ _| (___ | | | | | ",
            r" | '_ \| | | |\___ \| | | | | ",
            r" | |_) | |_| |____) | |__| | |____ ",
            r" | .__/ \__, |_____/|_____/ \_____|",
            r" | | __/ | ",
            r" |_| |___/ ",
            r" ",
        )
        self.logger.info("\n".join(banner))

    def dump_setup(self, step, controller_params, description):
        """
        Helper function to dump the setup used for this controller

        Parameters that have been set by the user are marked with an arrow.
        Parameter groups that are missing in the description are treated as
        empty, i.e. none of their parameters is marked as user-defined.

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem
        """

        # user-defined parameter sets; a missing entry means nothing was set by the user
        step_params = description.get('step_params', {})
        level_params = description.get('level_params', {})
        problem_params = description.get('problem_params', {})
        sweeper_params = description.get('sweeper_params', {})

        self.welcome_message()
        out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN'
        self.logger.info(out)
        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        for k, v in sorted(vars(self.params).items()):
            if not k.startswith('_'):
                if k in controller_params:
                    out += '--> %s = %s\n' % (k, v)
                else:
                    out += ' %s = %s\n' % (k, v)

        out += '\nStep: %s\n' % step.__class__
        for k, v in sorted(vars(step.params).items()):
            if not k.startswith('_'):
                if k in step_params:
                    out += '--> %s = %s\n' % (k, v)
                else:
                    out += ' %s = %s\n' % (k, v)
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            for k, v in sorted(vars(L.params).items()):
                if not k.startswith('_'):
                    if k in level_params:
                        out += '--> %s = %s\n' % (k, v)
                    else:
                        out += ' %s = %s\n' % (k, v)
            out += '--> Problem: %s\n' % L.prob.__class__
            for k, v in sorted(L.prob.params.items()):
                if k in problem_params:
                    out += '--> %s = %s\n' % (k, v)
                else:
                    out += ' %s = %s\n' % (k, v)
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            for k, v in sorted(vars(L.sweep.params).items()):
                if not k.startswith('_'):
                    if k in sweeper_params:
                        out += '--> %s = %s\n' % (k, v)
                    else:
                        out += ' %s = %s\n' % (k, v)
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        # transfer operators are only listed when there is more than one level
        if len(step.levels) > 1:
            base_transfer_params = description.get('base_transfer_params', {})
            space_transfer_params = description.get('space_transfer_params', {})

            if 'base_transfer_class' in description and description['base_transfer_class'] is not BaseTransfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            for k, v in sorted(vars(step.base_transfer.params).items()):
                if not k.startswith('_'):
                    if k in base_transfer_params:
                        out += '--> %s = %s\n' % (k, v)
                    else:
                        out += ' %s = %s\n' % (k, v)
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            for k, v in sorted(vars(step.base_transfer.space_transfer.params).items()):
                if not k.startswith('_'):
                    if k in space_transfer_params:
                        out += '--> %s = %s\n' % (k, v)
                    else:
                        out += ' %s = %s\n' % (k, v)

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        out = '----------------------------------------------------------------------------------------------------'
        self.logger.info(out)
        out = 'Setup overview (--> user-defined, -> dependency) -- END\n'
        self.logger.info(out)

    def run(self, u0, t0, Tend):
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always, derived controllers have to override this method
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self):
        """
        Getter for the hooks

        Returns:
            list: the hook instances in the order in which they have been added
        """
        return self.__hooks

    def setup_convergence_controllers(self, description):
        '''
        Setup variables needed for convergence controllers, notably a list containing all of them and a list containing
        their order. All convergence controllers found under the key `convergence_controllers` in the description are
        instantiated.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers = []
        self.convergence_controller_order = []
        conv_classes = description.get('convergence_controllers', {})

        # instantiate the convergence controllers
        for conv_class, params in conv_classes.items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(self, convergence_controller, description, params=None, allow_double=False):
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        # check if we passed any sort of special params; `useMPI` is always taken from the controller
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        if convergence_controller not in [type(me) for me in self.convergence_controllers] or allow_double:
            self.convergence_controllers.append(convergence_controller(self, params, description))

            # update ordering
            orders = [C.params.control_order for C in self.convergence_controllers]
            self.convergence_controller_order = np.arange(len(self.convergence_controllers))[np.argsort(orders)]

        return None

    def get_convergence_controllers_as_table(self, description):
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        '''
        user_classes = description.get('convergence_controllers', {})

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        # walk through the convergence controllers in the order of their control order
        for i, index in enumerate(self.convergence_controller_order):
            C = self.convergence_controllers[index]

            # figure out how the convergence controller was added
            if type(C) in user_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

        return out

    def return_stats(self):
        """
        Return the merged stats from all hooks

        If several hooks report the same key, the hook that was added last wins.

        Returns:
            dict: Merged stats from all hooks
        """
        stats = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats
class ParaDiagController(Controller):
    """
    Base class for ParaDiag controllers

    Validates and completes the controller parameters for ParaDiag and provides
    the weighted (inverse) FFT in time. The method ``apply_matrix`` used by the
    FFTs is not defined here (presumably it is supplied by the derived
    controllers -- confirm against the implementations).
    """

    def __init__(self, controller_params, description, n_steps, useMPI=None):
        """
        Initialization routine for ParaDiag controllers

        The controller parameters are modified in place: `all_to_done` is set to True and `average_jacobian` defaults
        to True. The entry `alpha` (the alpha parameter for ParaDiag) is mandatory. If the sweeper class is derived
        from `QDiagonalization`, the sweeper parameters `ignore_ic` and `update_f_evals` are set in the description.

        Args:
            controller_params (dict): parameter set for the controller and the steps, must contain `alpha`
            description (dict): all the parameters to set up the rest (levels, problems, transfer, ...)
            n_steps (int): Number of parallel steps
            useMPI (bool): passed on to the base controller

        Raises:
            NotImplementedError: if `all_to_done` is explicitly switched off
            ParameterError: if `alpha` is missing in the controller parameters
        """
        from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization

        if QDiagonalization in description['sweeper_class'].__mro__:
            description['sweeper_params']['ignore_ic'] = True
            description['sweeper_params']['update_f_evals'] = False
        else:
            logging.getLogger('controller').warning(
                f'Warning: Your sweeper class {description["sweeper_class"]} is not derived from {QDiagonalization}. You probably want to use another sweeper class.'
            )

        # Only an explicit `all_to_done=False` is an error. Accepting True also allows to reuse a parameter
        # dictionary that has already been completed by a previous instantiation (see below).
        if not controller_params.get('all_to_done', True):
            raise NotImplementedError('ParaDiag only implemented with option `all_to_done=True`')
        if 'alpha' not in controller_params.keys():
            from pySDC.core.errors import ParameterError

            raise ParameterError('Please supply alpha as a parameter to the ParaDiag controller!')
        controller_params['average_jacobian'] = controller_params.get('average_jacobian', True)

        controller_params['all_to_done'] = True
        super().__init__(controller_params=controller_params, description=description, useMPI=useMPI)

        self.n_steps = n_steps

    def FFT_in_time(self, quantity):
        """
        Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        The FFT matrix is computed on the first call and reused afterwards.

        Note: The implementation via matrix-vector multiplication may be inefficient and less stable compared to an FFT
        with transposes!

        Args:
            quantity: passed on to `apply_matrix` together with the FFT matrix
        """
        # `hasattr(self, '__FFT_matrix')` can never succeed because of name mangling, so test the attribute itself
        try:
            matrix = self.__FFT_matrix
        except AttributeError:
            from pySDC.helpers.ParaDiagHelper import get_weighted_FFT_matrix

            matrix = self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha)

        self.apply_matrix(matrix, quantity)

    def iFFT_in_time(self, quantity):
        """
        Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        The inverse FFT matrix is computed on the first call and reused afterwards.

        Args:
            quantity: passed on to `apply_matrix` together with the inverse FFT matrix
        """
        # see FFT_in_time for why the cached matrix is looked up this way
        try:
            matrix = self.__iFFT_matrix
        except AttributeError:
            from pySDC.helpers.ParaDiagHelper import get_weighted_iFFT_matrix

            matrix = self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha)

        self.apply_matrix(matrix, quantity)