Coverage for pySDC / core / controller.py: 98%
199 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-20 16:04 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-20 16:04 +0000
1import logging
2import os
3import sys
4from typing import Any, Dict, List, Optional, Type, Union
5import numpy as np
7from pySDC.core.base_transfer import BaseTransfer
8from pySDC.helpers.pysdc_helper import FrozenClass
9from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
10from pySDC.implementations.hooks.default_hook import DefaultHooks
11from pySDC.implementations.hooks.log_timings import CPUTimings
# short helper class to add params as attributes
class _Pars(FrozenClass):
    """
    Controller parameters exposed as attributes.

    Every option first receives its default value, then each entry of `params` is assigned on top of that
    (overriding a default or adding a new attribute). Finally `_freeze()` is called on the instance.
    """

    def __init__(self, params: Dict[str, Any]) -> None:
        defaults: Dict[str, Any] = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            'fname': f'run_pid{os.getpid()}.log',
            'use_iteration_estimator': False,
        }

        # defaults go first so that user-supplied values win
        for source in (defaults, params):
            for key, value in source.items():
                setattr(self, key, value)

        self._freeze()
class Controller(object):
    """
    Base abstract controller class

    Sets up hooks, logging and convergence controllers. Derived controllers have to implement `run`.
    """

    def __init__(
        self, controller_params: Dict[str, Any], description: Dict[str, Any], useMPI: Optional[bool] = None
    ) -> None:
        """
        Initialization routine for the base controller

        Args:
            controller_params (dict): parameter set for the controller and the steps. The entry `hook_class` may
                be a single hook class or a list/tuple of hook classes. It is overwritten in place with the full
                list of hook classes requested, including the default ones.
            description (dict): description of the problem, used to set up the convergence controllers
            useMPI (bool): forwarded to every convergence controller as parameter `useMPI`
        """
        self.useMPI: Optional[bool] = useMPI
        self.description: Dict[str, Any] = description

        # the default hooks are always requested, user-supplied hooks are added on top of them
        self.__hooks: List[Any] = []
        hook_classes: List[Type[Any]] = [DefaultHooks, CPUTimings]
        user_hooks = controller_params.get('hook_class', [])
        if user_hooks is None:
            user_hooks = []
        if isinstance(user_hooks, (list, tuple)):
            hook_classes.extend(user_hooks)
        else:
            hook_classes.append(user_hooks)
        for hook in hook_classes:
            self.add_hook(hook)
        # written back before `_Pars` is built, so `self.params.hook_class` holds the complete list
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params: _Pars = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger: logging.Logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers: List[Type[Any]] = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(
        level: Optional[int] = None, log_to_file: Optional[bool] = None, fname: Optional[str] = None
    ) -> None:
        """
        Helper function to set main parameters for the logging facility

        All handlers attached to the root logger are replaced by a handler writing to stdout and, optionally, a
        handler writing to a file. Handlers installed by an earlier call of this function are closed when they
        are removed, so their files are released. Handlers installed by other code are only detached.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file; messages are appended if the file already exists
        """
        assert type(level) is int

        # specify formats and handlers
        new_handlers: List[logging.Handler] = []
        file_handler = None
        if log_to_file:
            file_formatter = logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
            )
            # mode 'a' appends to an existing file and creates a missing one
            file_handler = logging.FileHandler(fname, mode='a')
            file_handler.setFormatter(file_formatter)

        std_formatter = logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s')

        if level <= logging.DEBUG:
            import warnings

            warnings.warn('Running with debug output will degrade performance as all output is immediately flushed.')

            class StreamFlushingHandler(logging.StreamHandler):
                """
                This will immediately flush any messages to the output.
                """

                def emit(self, record: logging.LogRecord) -> None:
                    super().emit(record)
                    self.flush()

            std_handler = StreamFlushingHandler(sys.stdout)
        else:
            std_handler = logging.StreamHandler(sys.stdout)

        std_handler.setFormatter(std_formatter)
        new_handlers.append(std_handler)
        if file_handler is not None:
            new_handlers.append(file_handler)

        # instantiate logger
        root_logger = logging.getLogger('')

        # remove handlers from previous calls to controller; the ones this function created are closed as well,
        # otherwise every previous FileHandler would keep its file open
        for handler in root_logger.handlers[:]:
            root_logger.removeHandler(handler)
            if getattr(handler, '_pySDC_controller_handler', False):
                handler.close()

        root_logger.setLevel(level)
        for handler in new_handlers:
            # tag the handler so that the next call knows it may close it
            handler._pySDC_controller_handler = True
            root_logger.addHandler(handler)

    def add_hook(self, hook: Type[Any]) -> None:
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            self.__hooks.append(hook())

    def welcome_message(self) -> None:
        """
        Log the pySDC banner at info level.
        """
        out = "\n".join(
            [
                "Welcome to the one and only, really very astonishing and 87.3% bug free",
                r" _____ _____ _____ ",
                r" / ____| __ \ / ____|",
                r" _ __ _ _| (___ | | | | | ",
                r" | '_ \| | | |\___ \| | | | | ",
                r" | |_) | |_| |____) | |__| | |____ ",
                r" | .__/ \__, |_____/|_____/ \_____|",
                r" | | __/ | ",
                r" |_| |___/ ",
                r" ",
            ]
        )
        self.logger.info(out)

    @staticmethod
    def _format_params(items: Any, user_keys: Any, skip_private: bool = True) -> str:
        """
        Format (key, value) pairs for `dump_setup`, one line per pair, sorted by key.

        Args:
            items: iterable of (key, value) pairs
            user_keys: container of keys set by the user; their lines are marked with '-->'
            skip_private (bool): omit keys starting with an underscore

        Returns:
            str: the formatted lines
        """
        out = ''
        for k, v in sorted(items):
            if skip_private and k.startswith('_'):
                continue
            out += ('--> %s = %s\n' if k in user_keys else ' %s = %s\n') % (k, v)
        return out

    def dump_setup(self, step: Any, controller_params: Dict[str, Any], description: Dict[str, Any]) -> None:
        """
        Helper function to dump the setup used for this controller

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem. Missing parameter sections are treated as empty,
                i.e. none of the corresponding values is marked as user-defined.
        """

        self.welcome_message()
        out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN'
        self.logger.info(out)
        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        out += self._format_params(vars(self.params).items(), controller_params)

        out += '\nStep: %s\n' % step.__class__
        out += self._format_params(vars(step.params).items(), description.get('step_params', {}))
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            out += self._format_params(vars(L.params).items(), description.get('level_params', {}))
            out += '--> Problem: %s\n' % L.prob.__class__
            # unlike the other parameter sets, private keys of the problem parameters are not filtered out
            out += self._format_params(
                L.prob.params.items(), description.get('problem_params', {}), skip_private=False
            )
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            out += self._format_params(vars(L.sweep.params).items(), description.get('sweeper_params', {}))
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        if len(step.levels) > 1:
            if 'base_transfer_class' in description and description['base_transfer_class'] is not BaseTransfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            out += self._format_params(
                vars(step.base_transfer.params).items(), description.get('base_transfer_params', {})
            )
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            out += self._format_params(
                vars(step.base_transfer.space_transfer.params).items(), description.get('space_transfer_params', {})
            )

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        out = '----------------------------------------------------------------------------------------------------'
        self.logger.info(out)
        out = 'Setup overview (--> user-defined, -> dependency) -- END\n'
        self.logger.info(out)

    def run(self, u0: Any, t0: float, Tend: float) -> Any:
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always, derived controllers have to override this method
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self) -> List[Any]:
        """
        Getter for the hooks

        Returns:
            pySDC.Hooks.hooks: hooks
        """
        return self.__hooks

    def setup_convergence_controllers(self, description: Dict[str, Any]) -> None:
        """
        Setup variables needed for convergence controllers, notably a list containing all of them and a list containing
        their order. Also, we add the `CheckConvergence` convergence controller, which takes care of maximum iteration
        count or a residual based stopping criterion, as well as all convergence controllers added to the description.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        """
        self.convergence_controllers: List[Any] = []
        # List of indices specifying the order of convergence controllers
        self.convergence_controller_order: List[int] = []
        conv_classes = description.get('convergence_controllers', {})

        # instantiate the convergence controllers
        for conv_class, params in conv_classes.items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(
        self,
        convergence_controller: Type[Any],
        description: Dict[str, Any],
        params: Optional[Dict[str, Any]] = None,
        allow_double: bool = False,
    ) -> None:
        """
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller; `useMPI` is always set from the controller
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        """
        # copy the params (never mutate the caller's dict) and add the MPI flag of this controller
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        if convergence_controller not in [type(me) for me in self.convergence_controllers] or allow_double:
            self.convergence_controllers.append(convergence_controller(self, params, description))

        # update ordering: indices into `convergence_controllers`, sorted by their `control_order`
        orders = [C.params.control_order for C in self.convergence_controllers]
        self.convergence_controller_order = np.arange(len(self.convergence_controllers))[np.argsort(orders)]

        return None

    def get_convergence_controllers_as_table(self, description: Dict[str, Any]) -> str:
        """
        This function is for debugging purposes to keep track of the different convergence controllers and their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        """
        user_classes = description.get('convergence_controllers', {}).keys()

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        for i, index in enumerate(self.convergence_controller_order):
            C = self.convergence_controllers[index]

            # figure out how the convergence controller was added
            if type(C) in user_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

        return out

    def return_stats(self) -> Dict[Any, Any]:
        """
        Return the merged stats from all hooks

        Returns:
            dict: Merged stats from all hooks; for duplicate keys, later hooks take precedence
        """
        stats: Dict[Any, Any] = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats
class ParaDiagController(Controller):
    """
    Base class for ParaDiag controllers.

    NOTE(review): `FFT_in_time` and `iFFT_in_time` call `self.apply_matrix(matrix, quantity)`, which is not
    defined in this class. Presumably it is implemented by derived controllers; confirm.
    """

    def __init__(
        self,
        controller_params: Dict[str, Any],
        description: Dict[str, Any],
        n_steps: int,
        useMPI: Optional[bool] = None,
    ) -> None:
        """
        Initialization routine for ParaDiag controllers

        Args:
            controller_params: parameter set for the controller and the steps. Must contain `alpha`, the
                parameter of the weighted FFTs in time. `all_to_done` is always switched on and `average_jacobian`
                defaults to True.
            description: all the parameters to set up the rest (levels, problems, transfer, ...). If the sweeper
                class is derived from `QDiagonalization`, the sweeper parameters `ignore_ic` and `update_f_evals`
                are set in place.
            n_steps (int): Number of parallel steps
            useMPI (bool): passed on to the base controller

        Raises:
            NotImplementedError: if `all_to_done` was explicitly switched off
            ParameterError: if `alpha` is missing from `controller_params`
        """
        from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization

        if QDiagonalization in description['sweeper_class'].__mro__:
            description['sweeper_params']['ignore_ic'] = True
            description['sweeper_params']['update_f_evals'] = False
        else:
            logging.getLogger('controller').warning(
                f'Warning: Your sweeper class {description["sweeper_class"]} is not derived from {QDiagonalization}. You probably want to use another sweeper class.'
            )

        # ParaDiag always runs with `all_to_done=True` (enforced below), so only an explicit request to switch it
        # off is an error. The previous check raised exactly when the user asked for the supported setting.
        if not controller_params.get('all_to_done', True):
            raise NotImplementedError('ParaDiag only implemented with option `all_to_done=True`')
        if 'alpha' not in controller_params:
            from pySDC.core.errors import ParameterError

            raise ParameterError('Please supply alpha as a parameter to the ParaDiag controller!')
        controller_params.setdefault('average_jacobian', True)

        controller_params['all_to_done'] = True
        super().__init__(controller_params=controller_params, description=description, useMPI=useMPI)

        self.n_steps: int = n_steps

    def FFT_in_time(self, quantity: Any) -> None:
        """
        Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        The matrix is built on the first call and cached on the instance afterwards.

        Note: The implementation via matrix-vector multiplication may be inefficient and less stable compared to an FFT
        with transposes!
        """
        # `self.__FFT_matrix` is name-mangled to `_ParaDiagController__FFT_matrix`. A string check such as
        # hasattr(self, '__FFT_matrix') never sees it, so access the attribute directly instead.
        try:
            fft_matrix = self.__FFT_matrix
        except AttributeError:
            from pySDC.helpers.ParaDiagHelper import get_weighted_FFT_matrix

            fft_matrix = self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha)

        self.apply_matrix(fft_matrix, quantity)

    def iFFT_in_time(self, quantity: Any) -> None:
        """
        Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        The matrix is built on the first call and cached on the instance afterwards.
        """
        # see FFT_in_time for why the mangled attribute is accessed directly
        try:
            ifft_matrix = self.__iFFT_matrix
        except AttributeError:
            from pySDC.helpers.ParaDiagHelper import get_weighted_iFFT_matrix

            ifft_matrix = self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha)

        self.apply_matrix(ifft_matrix, quantity)