Source code for core.controller

import logging
import os
import sys
import numpy as np

from pySDC.core.base_transfer import BaseTransfer
from pySDC.helpers.pysdc_helper import FrozenClass
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
from pySDC.implementations.hooks.default_hook import DefaultHooks
from pySDC.implementations.hooks.log_timings import CPUTimings


# short helper class to add params as attributes
class _Pars(FrozenClass):
    """
    Attribute container for the controller parameters.

    Every default below is assigned as an attribute first. Afterwards each entry of `params` is assigned
    (overriding a default of the same name or adding a new attribute). Finally `_freeze()` (inherited from
    FrozenClass) is called.

    Args:
        params (dict): user-supplied controller parameters
    """

    def __init__(self, params):
        defaults = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            'fname': f'run_pid{os.getpid()}.log',
            'use_iteration_estimator': False,
        }
        for name, value in defaults.items():
            setattr(self, name, value)

        # user values are applied after the defaults so that they win
        for name, value in params.items():
            setattr(self, name, value)

        self._freeze()


class Controller(object):
    """
    Base abstract controller class
    """

    def __init__(self, controller_params, description, useMPI=None):
        """
        Initialization routine for the base controller

        Args:
            controller_params (dict): parameter set for the controller and the steps. Its 'hook_class' entry
                is overwritten with the full list of hook classes (default hooks followed by user hooks).
            description (dict): description of the problem, passed on to setup_convergence_controllers
            useMPI (bool): flag stored on the controller and handed to every convergence controller as
                the 'useMPI' parameter
        """
        self.useMPI = useMPI

        # the default hooks are always registered; user hooks (single class or list/tuple) come on top
        self.__hooks = []
        hook_classes = [DefaultHooks, CPUTimings]
        user_hooks = controller_params.get('hook_class', [])
        if user_hooks is None:
            user_hooks = []
        elif not isinstance(user_hooks, (list, tuple)):
            user_hooks = [user_hooks]
        hook_classes += list(user_hooks)
        for hook in hook_classes:
            self.add_hook(hook)
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(level=None, log_to_file=None, fname=None):
        """
        Helper function to set main parameters for the logging facility

        Reconfigures the root logger:
        - All handlers attached to it are removed; removed file handlers are also closed.
        - The given level is set.
        - A stdout handler is attached, plus a handler appending to `fname` if requested.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file, only used if log_to_file is set
        """

        assert type(level) is int

        # specify formats and handlers
        std_formatter = logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s')
        std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(std_formatter)

        file_handler = None
        if log_to_file:
            file_formatter = logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
            )
            # mode 'a' appends to an existing file and creates the file if it does not exist yet
            file_handler = logging.FileHandler(fname, mode='a')
            file_handler.setFormatter(file_formatter)

        # instantiate logger
        logger = logging.getLogger('')

        # remove handlers from previous calls to controller
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            # file handlers hold an open file; close them so repeated setups do not leak file descriptors
            if isinstance(handler, logging.FileHandler):
                handler.close()

        logger.setLevel(level)
        logger.addHandler(std_handler)
        if file_handler is not None:
            logger.addHandler(file_handler)

    def add_hook(self, hook):
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something
        happens. The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            self.__hooks += [hook()]

    def welcome_message(self):
        """
        Log the pySDC welcome banner at info level.
        """
        out = (
            "Welcome to the one and only, really very astonishing and 87.3% bug free"
            + "\n"
            + r" _____ _____ _____ "
            + "\n"
            + r" / ____| __ \ / ____|"
            + "\n"
            + r" _ __ _ _| (___ | | | | | "
            + "\n"
            + r" | '_ \| | | |\___ \| | | | | "
            + "\n"
            + r" | |_) | |_| |____) | |__| | |____ "
            + "\n"
            + r" | .__/ \__, |_____/|_____/ \_____|"
            + "\n"
            + r" | | __/ | "
            + "\n"
            + r" |_| |___/ "
            + "\n"
            + r" "
        )
        self.logger.info(out)

    @staticmethod
    def _format_params(params, user_params, skip_private=True):
        """
        Format parameters one per line, sorted by name, marking the ones set by the user with '-->'.

        Args:
            params (dict): parameter name/value pairs
            user_params: container of parameter names that were set explicitly by the user
            skip_private (bool): skip parameters whose name starts with an underscore

        Returns:
            str: the formatted lines
        """
        out = ''
        for k, v in sorted(params.items()):
            if skip_private and k.startswith('_'):
                continue
            if k in user_params:
                out += '--> %s = %s\n' % (k, v)
            else:
                out += ' %s = %s\n' % (k, v)
        return out

    def dump_setup(self, step, controller_params, description):
        """
        Helper function to dump the setup used for this controller

        Optional parameter dicts missing from `description` (e.g. 'step_params') are treated as empty, i.e.
        all corresponding parameters are reported as not user-defined.

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem
        """
        self.welcome_message()
        out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN'
        self.logger.info(out)
        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        out += self._format_params(vars(self.params), controller_params)

        out += '\nStep: %s\n' % step.__class__
        out += self._format_params(vars(step.params), description.get('step_params', {}))
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            out += self._format_params(vars(L.params), description.get('level_params', {}))
            out += '--> Problem: %s\n' % L.prob.__class__
            # problem parameters are a plain dict and are listed without filtering private names
            out += self._format_params(L.prob.params, description.get('problem_params', {}), skip_private=False)
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            out += self._format_params(vars(L.sweep.params), description.get('sweeper_params', {}))
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        if len(step.levels) > 1:
            if 'base_transfer_class' in description and description['base_transfer_class'] is not BaseTransfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            out += self._format_params(
                vars(step.base_transfer.params), description.get('base_transfer_params', {})
            )
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            out += self._format_params(
                vars(step.base_transfer.space_transfer.params), description.get('space_transfer_params', {})
            )

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        out = '----------------------------------------------------------------------------------------------------'
        self.logger.info(out)
        out = 'Setup overview (--> user-defined, -> dependency) -- END\n'
        self.logger.info(out)

    def run(self, u0, t0, Tend):
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always, concrete controllers have to override this
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self):
        """
        Getter for the hooks

        Returns:
            pySDC.Hooks.hooks: hooks
        """
        return self.__hooks

    def setup_convergence_controllers(self, description):
        '''
        Reset the list of convergence controllers and the list holding their order. Then instantiate every
        convergence controller listed in `description['convergence_controllers']`, together with its
        parameters.

        The classes in `self.base_convergence_controllers` (e.g. `CheckConvergence`) are NOT added here.
        NOTE(review): presumably the concrete controllers add them after calling this; confirm.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers = []
        self.convergence_controller_order = []
        conv_classes = description.get('convergence_controllers', {})

        # instantiate the convergence controllers
        for conv_class, params in conv_classes.items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(self, convergence_controller, description, params=None, allow_double=False):
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated.

        Controllers are ordered by ascending `params.control_order`; controllers with equal order keep the
        order in which they were added.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        # copy the user params (never mutate the caller's dict) and add the MPI flag
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        if convergence_controller not in [type(me) for me in self.convergence_controllers] or allow_double:
            self.convergence_controllers.append(convergence_controller(self, params, description))

        # update ordering; a stable sort guarantees that ties keep their insertion order
        orders = [C.params.control_order for C in self.convergence_controllers]
        self.convergence_controller_order = np.argsort(orders, kind='stable')

        return None

    def get_convergence_controllers_as_table(self, description):
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and
        their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        '''
        user_classes = description.get('convergence_controllers', {})

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        for i, index in enumerate(self.convergence_controller_order):
            controller = self.convergence_controllers[index]

            # figure out how the convergence controller was added
            if type(controller) in user_classes:  # added by user
                user_added = '--> '
            elif type(controller) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {controller.params.control_order:5} | {type(controller).__name__}'

        return out

    def return_stats(self):
        """
        Return the merged stats from all hooks. For duplicate keys, the entry of the later hook wins.

        Returns:
            dict: Merged stats from all hooks
        """
        stats = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats