Coverage for pySDC/core/controller.py: 99%
165 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import logging
2import os
3import sys
4import numpy as np
6from pySDC.core.base_transfer import BaseTransfer
7from pySDC.helpers.pysdc_helper import FrozenClass
8from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
9from pySDC.implementations.hooks.default_hook import DefaultHooks
10from pySDC.implementations.hooks.log_timings import CPUTimings
# short helper class to add params as attributes
class _Pars(FrozenClass):
    """
    Attribute container for the controller parameters.

    Every default below is assigned first; afterwards every entry of ``params`` is assigned, overriding a
    default of the same name or adding a new attribute. Finally ``_freeze()`` (inherited from FrozenClass)
    is called, presumably to reject attributes that do not exist yet -- see FrozenClass to confirm.
    """

    def __init__(self, params):
        """
        Args:
            params (dict): user-supplied controller parameters (anything offering ``.items()``)
        """
        defaults = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            # per-process default log file name
            'fname': f'run_pid{os.getpid()}.log',
            'use_iteration_estimator': False,
        }

        # defaults first, user values second, so user values win
        for source in (defaults, params):
            for name, value in source.items():
                setattr(self, name, value)

        self._freeze()
class Controller(object):
    """
    Base abstract controller class
    """

    def __init__(self, controller_params, description, useMPI=None):
        """
        Initialization routine for the base controller

        Args:
            controller_params (dict): parameter set for the controller and the steps. The entry 'hook_class'
                is overwritten with the full list of hook classes (default hooks followed by the user hooks).
            description (dict): description of the problem, used to set up the convergence controllers
            useMPI (bool): stored on the controller and passed as 'useMPI' to every convergence controller
        """
        self.useMPI = useMPI

        # the default hooks are always used; the user may add a single hook class or a list/tuple of them
        self.__hooks = []
        hook_classes = [DefaultHooks, CPUTimings]
        user_hooks = controller_params.get('hook_class', [])
        if user_hooks is None:
            user_hooks = []
        if isinstance(user_hooks, (list, tuple)):
            hook_classes.extend(user_hooks)
        else:
            hook_classes.append(user_hooks)
        for hook in hook_classes:
            self.add_hook(hook)
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(level=None, log_to_file=None, fname=None):
        """
        Helper function to set main parameters for the logging facility

        All handlers of the root logger are replaced by a stdout handler and, optionally, a file handler.
        File handlers created by earlier calls of this function are closed when they are replaced, so the
        log file is not left open. Handlers installed by anybody else are removed but not closed.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file; messages are appended if the file already exists
        """
        # name tag identifying the file handlers created here (and therefore owned by this function)
        own_file_handler_name = 'pySDC_controller_file_handler'

        assert type(level) is int

        # instantiate logger
        logger = logging.getLogger('')

        # remove handlers from previous calls to controller and release the log files we opened ourselves
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            if handler.get_name() == own_file_handler_name:
                handler.close()

        logger.setLevel(level)

        std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s'))
        logger.addHandler(std_handler)

        if log_to_file:
            # mode 'a' creates a missing file and appends to an existing one
            file_handler = logging.FileHandler(fname, mode='a')
            file_handler.set_name(own_file_handler_name)
            file_handler.setFormatter(
                logging.Formatter(
                    fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
                )
            )
            logger.addHandler(file_handler)

    def add_hook(self, hook):
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            self.__hooks.append(hook())

    def welcome_message(self):
        """
        Log the pySDC banner at INFO level.
        """
        out = (
            "Welcome to the one and only, really very astonishing and 87.3% bug free"
            + "\n"
            + r" _____ _____ _____ "
            + "\n"
            + r" / ____| __ \ / ____|"
            + "\n"
            + r" _ __ _ _| (___ | | | | | "
            + "\n"
            + r" | '_ \| | | |\___ \| | | | | "
            + "\n"
            + r" | |_) | |_| |____) | |__| | |____ "
            + "\n"
            + r" | .__/ \__, |_____/|_____/ \_____|"
            + "\n"
            + r" | | __/ | "
            + "\n"
            + r" |_| |___/ "
            + "\n"
            + r" "
        )
        self.logger.info(out)

    @staticmethod
    def _params_as_text(params, user_keys, skip_private=True):
        """
        Format parameters as one ``key = value`` line each, sorted by key.

        Args:
            params (dict): parameters to format
            user_keys: container of the keys set by the user; those lines are prefixed with '-->'
            skip_private (bool): omit keys that start with an underscore

        Returns:
            str: the formatted lines, each terminated by a newline
        """
        out = ''
        for k, v in sorted(params.items()):
            if skip_private and k.startswith('_'):
                continue
            template = '--> %s = %s\n' if k in user_keys else ' %s = %s\n'
            out += template % (k, v)
        return out

    def dump_setup(self, step, controller_params, description):
        """
        Helper function to dump the setup used for this controller

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem; missing '*_params' entries count as empty
        """

        self.welcome_message()
        out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN'
        self.logger.info(out)
        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        out += self._params_as_text(vars(self.params), controller_params)

        out += '\nStep: %s\n' % step.__class__
        out += self._params_as_text(vars(step.params), description.get('step_params', {}))
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            out += self._params_as_text(vars(L.params), description.get('level_params', {}))
            out += '--> Problem: %s\n' % L.prob.__class__
            # problem parameters are listed completely, including underscore-prefixed ones
            out += self._params_as_text(L.prob.params, description.get('problem_params', {}), skip_private=False)
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            out += self._params_as_text(vars(L.sweep.params), description.get('sweeper_params', {}))
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        if len(step.levels) > 1:
            if 'base_transfer_class' in description and description['base_transfer_class'] is not BaseTransfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            out += self._params_as_text(
                vars(step.base_transfer.params), description.get('base_transfer_params', {})
            )
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            out += self._params_as_text(
                vars(step.base_transfer.space_transfer.params), description.get('space_transfer_params', {})
            )

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        out = '----------------------------------------------------------------------------------------------------'
        self.logger.info(out)
        out = 'Setup overview (--> user-defined, -> dependency) -- END\n'
        self.logger.info(out)

    def run(self, u0, t0, Tend):
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always; subclasses have to implement this
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self):
        """
        Getter for the hooks

        Returns:
            list: the hook instances attached to this controller
        """
        return self.__hooks

    def setup_convergence_controllers(self, description):
        '''
        Setup variables needed for convergence controllers, notably a list containing all of them and a list containing
        their order. Also, we add the `CheckConvergence` convergence controller, which takes care of maximum iteration
        count or a residual based stopping criterion, as well as all convergence controllers added to the description.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers = []
        self.convergence_controller_order = []
        conv_classes = description.get('convergence_controllers', {})

        # instantiate the convergence controllers
        for conv_class, params in conv_classes.items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(self, convergence_controller, description, params=None, allow_double=False):
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller; 'useMPI' is always set from the controller
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        # copy the params so the caller's dict is not modified, and inject the controller's MPI flag
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        if allow_double or convergence_controller not in [type(me) for me in self.convergence_controllers]:
            self.convergence_controllers.append(convergence_controller(self, params, description))

        # update ordering; a stable sort keeps controllers with equal control_order in insertion order
        orders = [C.params.control_order for C in self.convergence_controllers]
        self.convergence_controller_order = np.argsort(orders, kind='stable')

        return None

    def get_convergence_controllers_as_table(self, description):
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        '''
        user_classes = description.get('convergence_controllers', {})

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        ordered = (self.convergence_controllers[j] for j in self.convergence_controller_order)
        for i, C in enumerate(ordered):
            # figure out how the convergence controller was added
            if type(C) in user_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

        return out

    def return_stats(self):
        """
        Return the merged stats from all hooks

        Returns:
            dict: Merged stats from all hooks; for duplicate keys, later hooks take precedence
        """
        stats = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats