import logging
import os
import sys
from typing import Any, Dict, List, Optional, Type, Union
import numpy as np
from pySDC.core.base_transfer import BaseTransfer
from pySDC.helpers.pysdc_helper import FrozenClass
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
from pySDC.implementations.hooks.default_hook import DefaultHooks
from pySDC.implementations.hooks.log_timings import CPUTimings
class _Pars(FrozenClass):
    """
    Short helper class to add params as attributes.

    Every known controller option is first set to its default value. Afterwards every entry of ``params`` is set as
    an attribute as well, overriding a default or adding a new attribute. Finally ``_freeze()`` of the
    ``FrozenClass`` base is called, which presumably prevents adding further attributes later on (see FrozenClass).
    """

    def __init__(self, params: Dict[str, Any]) -> None:
        defaults: Dict[str, Any] = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            # default log file name contains the id of the current process
            'fname': 'run_pid' + str(os.getpid()) + '.log',
            'use_iteration_estimator': False,
        }

        # defaults first, user-supplied values second, so that the latter win
        for source in (defaults, params):
            for name, value in source.items():
                setattr(self, name, value)

        self._freeze()
class Controller(object):
    """
    Base abstract controller class

    Sets up the hooks, the controller parameters, the logging facility and the convergence controllers shared by all
    concrete controllers. Subclasses have to implement :meth:`run`.
    """

    def __init__(
        self, controller_params: Dict[str, Any], description: Dict[str, Any], useMPI: Optional[bool] = None
    ) -> None:
        """
        Initialization routine for the base controller

        Args:
            controller_params (dict): parameter set for the controller and the steps. The entry ``'hook_class'`` is
                overwritten in place with the complete list of hook classes in use.
            description (dict): all the parameters to set up the rest (levels, problems, transfer, ...)
            useMPI (bool): passed on to every convergence controller as its ``useMPI`` parameter
        """
        self.useMPI: Optional[bool] = useMPI
        self.description: Dict[str, Any] = description

        # the default hooks are always used; user hooks (a single class or a list/tuple of classes) come on top
        self.__hooks: List[Any] = []
        hook_classes: List[Type[Any]] = [DefaultHooks, CPUTimings]
        user_hooks = controller_params.get('hook_class', [])
        if isinstance(user_hooks, (list, tuple)):
            hook_classes.extend(user_hooks)
        elif user_hooks is not None:
            hook_classes.append(user_hooks)
        for hook in hook_classes:
            self.add_hook(hook)
        # write back the complete list so that the stored parameters reflect the hooks actually in use
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params: _Pars = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger: logging.Logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        # convergence controllers reported as "added by default" in get_convergence_controllers_as_table
        self.base_convergence_controllers: List[Type[Any]] = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(
        level: Optional[int] = None, log_to_file: Optional[bool] = None, fname: Optional[str] = None
    ) -> None:
        """
        Helper function to set main parameters for the logging facility

        All handlers currently attached to the root logger are removed. They are replaced by a handler writing to
        ``sys.stdout`` and, if requested, a handler appending to ``fname``. Handlers installed by an earlier call of
        this function are closed when they are removed, so creating several controllers in one process does not
        leave log files open.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file (only used if ``log_to_file``); an existing file is appended to
        """
        assert type(level) is int

        # name tag of the handlers created here, used to recognise (and close) them on the next call
        own_handler_name = 'pySDC.controller'

        # specify formats and handlers
        file_handler: Optional[logging.Handler] = None
        if log_to_file:
            file_formatter = logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
            )
            # mode 'a' creates a missing file and appends to an existing one
            file_handler = logging.FileHandler(fname, mode='a')
            file_handler.setFormatter(file_formatter)

        std_formatter = logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s')
        if level <= logging.DEBUG:
            import warnings

            warnings.warn('Running with debug output will degrade performance as all output is immediately flushed.')

            class StreamFlushingHandler(logging.StreamHandler):
                """
                This will immediately flush any messages to the output.
                """

                # NOTE(review): logging.StreamHandler.emit already flushes after every record, so this extra flush
                # looks redundant; kept to preserve behaviour.
                def emit(self, record: logging.LogRecord) -> None:
                    super().emit(record)
                    self.flush()

            std_handler: logging.Handler = StreamFlushingHandler(sys.stdout)
        else:
            std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(std_formatter)

        # instantiate logger
        logger = logging.getLogger('')

        # remove handlers from previous calls to controller; close the ones created here to release file handles
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            if handler.get_name() == own_handler_name:
                handler.close()

        logger.setLevel(level)
        # tag the new handlers only after closing the old ones (closing unregisters the name)
        std_handler.set_name(own_handler_name)
        logger.addHandler(std_handler)
        if file_handler is not None:
            file_handler.set_name(own_handler_name)
            logger.addHandler(file_handler)

    def add_hook(self, hook: Type[Any]) -> None:
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            # hooks are stored as instances; the class is instantiated without arguments
            self.__hooks.append(hook())

    def welcome_message(self) -> None:
        """
        Log the pySDC welcome banner at INFO level.
        """
        out = "\n".join(
            (
                "Welcome to the one and only, really very astonishing and 87.3% bug free",
                r" _____ _____ _____ ",
                r" / ____| __ \ / ____|",
                r" _ __ _ _| (___ | | | | | ",
                r" | '_ \| | | |\___ \| | | | | ",
                r" | |_) | |_| |____) | |__| | |____ ",
                r" | .__/ \__, |_____/|_____/ \_____|",
                r" | | __/ | ",
                r" |_| |___/ ",
                r" ",
            )
        )
        self.logger.info(out)

    @staticmethod
    def _format_param_lines(params: Dict[str, Any], user_keys: Any, skip_private: bool = True) -> str:
        """
        Format parameters as one ``key = value`` line each, sorted by key, for :meth:`dump_setup`.

        Args:
            params (dict): parameters to print
            user_keys: container of the keys supplied by the user; those lines are prefixed with ``-->``
            skip_private (bool): leave out keys starting with an underscore

        Returns:
            str: the formatted lines, each terminated by a newline
        """
        out = ''
        for k, v in sorted(params.items()):
            if skip_private and k.startswith('_'):
                continue
            if k in user_keys:
                out += '--> %s = %s\n' % (k, v)
            else:
                out += ' %s = %s\n' % (k, v)
        return out

    def dump_setup(self, step: Any, controller_params: Dict[str, Any], description: Dict[str, Any]) -> None:
        """
        Helper function to dump the setup used for this controller

        Logs the welcome message, then an overview of the parameters of the controller, the step and every level
        (problem, sweeper, collocation). For multi-level steps it also lists the transfer classes. Parameters
        supplied by the user are marked with ``-->``. Missing ``*_params`` entries in ``description`` are treated as
        empty.

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem
        """
        fmt = self._format_param_lines

        self.welcome_message()
        out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN'
        self.logger.info(out)
        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        out += fmt(vars(self.params), controller_params)

        out += '\nStep: %s\n' % step.__class__
        out += fmt(vars(step.params), description.get('step_params', {}))
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            out += fmt(vars(L.params), description.get('level_params', {}))
            out += '--> Problem: %s\n' % L.prob.__class__
            # problem parameters are iterated via .items() and printed including keys starting with '_'
            out += fmt(L.prob.params, description.get('problem_params', {}), skip_private=False)
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            out += fmt(vars(L.sweep.params), description.get('sweeper_params', {}))
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        # transfer information is only printed for multi-level steps
        if len(step.levels) > 1:
            if description.get('base_transfer_class', BaseTransfer) is not BaseTransfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            out += fmt(vars(step.base_transfer.params), description.get('base_transfer_params', {}))
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            out += fmt(
                vars(step.base_transfer.space_transfer.params), description.get('space_transfer_params', {})
            )

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        out = '----------------------------------------------------------------------------------------------------'
        self.logger.info(out)
        out = 'Setup overview (--> user-defined, -> dependency) -- END\n'
        self.logger.info(out)

    def run(self, u0: Any, t0: float, Tend: float) -> Any:
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always; concrete controllers must override this method
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self) -> List[Any]:
        """
        Getter for the hooks

        Returns:
            list: the hook instances, in the order they were added
        """
        return self.__hooks

    def setup_convergence_controllers(self, description: Dict[str, Any]) -> None:
        '''
        Setup variables needed for convergence controllers, notably a list containing all of them and a list
        containing their order. Then instantiate every convergence controller listed in
        ``description['convergence_controllers']``.

        NOTE(review): the classes in ``self.base_convergence_controllers`` (e.g. ``CheckConvergence``) are not added
        here; presumably a subclass adds them via :meth:`add_convergence_controller` -- confirm.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers: List[Any] = []
        # indices into self.convergence_controllers, sorted by control_order
        self.convergence_controller_order: List[int] = []

        for conv_class, params in description.get('convergence_controllers', {}).items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(
        self,
        convergence_controller: Type[Any],
        description: Dict[str, Any],
        params: Optional[Dict[str, Any]] = None,
        allow_double: bool = False,
    ) -> None:
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller (not modified; ``useMPI`` is added to a copy)
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        if allow_double or convergence_controller not in [type(me) for me in self.convergence_controllers]:
            self.convergence_controllers.append(convergence_controller(self, params, description))

            # update ordering; a stable sort keeps insertion order for equal control_order values
            orders = [C.params.control_order for C in self.convergence_controllers]
            self.convergence_controller_order = np.argsort(orders, kind='stable')

        return None

    def get_convergence_controllers_as_table(self, description: Dict[str, Any]) -> str:
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        '''
        user_classes = description.get('convergence_controllers', {})

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        for i, index in enumerate(self.convergence_controller_order):
            C = self.convergence_controllers[index]

            # figure out how the convergence controller was added
            if type(C) in user_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

        return out

    def return_stats(self) -> Dict[Any, Any]:
        """
        Return the merged stats from all hooks

        Returns:
            dict: Merged stats from all hooks; for duplicate keys the hook added later wins
        """
        stats: Dict[Any, Any] = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats
class ParaDiagController(Controller):
    """
    Base class for ParaDiag controllers.

    Enforces ``all_to_done=True``, requires the ParaDiag parameter ``alpha`` and provides weighted FFTs in time.
    NOTE(review): ``apply_matrix`` used by the FFT methods is not defined in this class; presumably subclasses
    provide it -- confirm.
    """

    def __init__(
        self,
        controller_params: Dict[str, Any],
        description: Dict[str, Any],
        n_steps: int,
        useMPI: Optional[bool] = None,
    ) -> None:
        """
        Initialization routine for ParaDiag controllers

        Args:
            controller_params (dict): parameter set for the controller and the steps. Must contain ``alpha``, and
                ``all_to_done`` must not be disabled. ``average_jacobian`` defaults to True and ``all_to_done`` is set
                to True; both are written into this dict.
            description (dict): all the parameters to set up the rest (levels, problems, transfer, ...). For
                ``QDiagonalization``-derived sweepers, ``ignore_ic`` and ``update_f_evals`` are set in
                ``description['sweeper_params']``.
            n_steps (int): Number of parallel steps
            useMPI (bool): passed on to the base controller

        Raises:
            NotImplementedError: if ``all_to_done`` is explicitly set to a falsy value
            ParameterError: if ``alpha`` is missing from ``controller_params``
        """
        # validate first, so that nothing is written into the caller's dicts when the setup is rejected.
        # An absent or True `all_to_done` is fine; it is forced to True below.
        if not controller_params.get('all_to_done', True):
            raise NotImplementedError('ParaDiag only implemented with option `all_to_done=True`')
        if 'alpha' not in controller_params:
            from pySDC.core.errors import ParameterError

            raise ParameterError('Please supply alpha as a parameter to the ParaDiag controller!')

        from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization

        if QDiagonalization in description['sweeper_class'].__mro__:
            description['sweeper_params']['ignore_ic'] = True
            description['sweeper_params']['update_f_evals'] = False
        else:
            logging.getLogger('controller').warning(
                f'Warning: Your sweeper class {description["sweeper_class"]} is not derived from {QDiagonalization}. You probably want to use another sweeper class.'
            )

        controller_params['average_jacobian'] = controller_params.get('average_jacobian', True)
        controller_params['all_to_done'] = True
        super().__init__(controller_params=controller_params, description=description, useMPI=useMPI)

        self.n_steps: int = n_steps

        # weighted (i)FFT matrices, built from n_steps and alpha on first use and then reused
        self.__FFT_matrix: Optional[Any] = None
        self.__iFFT_matrix: Optional[Any] = None

    def FFT_in_time(self, quantity: Any) -> None:
        """
        Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        Note: The implementation via matrix-vector multiplication may be inefficient and less stable compared to an FFT
        with transposes!

        Args:
            quantity: data handed to ``self.apply_matrix`` together with the weighted FFT matrix
        """
        # the matrix is cached on the (name-mangled) attribute set up in __init__
        if self.__FFT_matrix is None:
            from pySDC.helpers.ParaDiagHelper import get_weighted_FFT_matrix

            self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha)
        self.apply_matrix(self.__FFT_matrix, quantity)

    def iFFT_in_time(self, quantity: Any) -> None:
        """
        Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        Args:
            quantity: data handed to ``self.apply_matrix`` together with the weighted inverse FFT matrix
        """
        if self.__iFFT_matrix is None:
            from pySDC.helpers.ParaDiagHelper import get_weighted_iFFT_matrix

            self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha)
        self.apply_matrix(self.__iFFT_matrix, quantity)