Coverage for pySDC/core/controller.py: 99%
164 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
1import logging
2import os
3import sys
4import numpy as np
6from pySDC.core.base_transfer import BaseTransfer
7from pySDC.helpers.pysdc_helper import FrozenClass
8from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
9from pySDC.implementations.hooks.default_hook import DefaultHooks
# short helper class to add params as attributes
class _Pars(FrozenClass):
    """
    Parameter container for the controller.

    Every default listed below is set as an attribute first. Then every entry of `params` is set, so user values
    override defaults and unknown keys become new attributes. Finally `_freeze()` (provided by FrozenClass) is called.
    """

    def __init__(self, params):
        default_values = (
            ('mssdc_jac', True),
            ('predict_type', None),
            ('all_to_done', False),
            ('logger_level', 20),
            ('log_to_file', False),
            ('dump_setup', True),
            ('fname', f'run_pid{os.getpid()}.log'),  # per-process default log file name
            ('use_iteration_estimator', False),
        )
        for attr_name, attr_value in default_values:
            setattr(self, attr_name, attr_value)

        # user-provided values are applied last so that they take precedence over the defaults
        for attr_name, attr_value in params.items():
            setattr(self, attr_name, attr_value)

        self._freeze()
class Controller(object):
    """
    Base abstract controller class

    Holds the controller parameters, the hooks and the convergence controllers. Concrete controllers have to
    implement `run`.
    """

    def __init__(self, controller_params, description, useMPI=None):
        """
        Initialization routine for the base controller

        Args:
            controller_params (dict): parameter set for the controller and the steps. Its 'hook_class' entry is
                overwritten with the full list of hook classes in use (DefaultHooks first).
            description (dict): description of the problem, used to set up the convergence controllers
            useMPI: stored on the controller and handed to every convergence controller as its 'useMPI' parameter
        """
        self.useMPI = useMPI

        # the default hooks are always used; user hooks may be a single class or a list/tuple of classes
        self.__hooks = []
        user_hooks = controller_params.get('hook_class', [])
        if not isinstance(user_hooks, (list, tuple)):
            user_hooks = [user_hooks]
        hook_classes = [DefaultHooks, *user_hooks]
        for hook in hook_classes:
            self.add_hook(hook)
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(level=None, log_to_file=None, fname=None):
        """
        Helper function to set main parameters for the logging facility

        All handlers currently attached to the root logger are removed. They are replaced by a stdout handler
        and, if requested, a file handler.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file (only used if log_to_file is set)
        """
        assert type(level) is int, f'logging level has to be an int, got {type(level).__name__}'

        file_handler = None
        if log_to_file:
            # mode 'a' appends to an existing log file and creates it if it does not exist yet
            file_handler = logging.FileHandler(fname, mode='a')
            file_handler.setFormatter(
                logging.Formatter(
                    fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
                )
            )

        std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s'))

        # configure the root logger, dropping handlers installed by previous calls to a controller
        logger = logging.getLogger('')
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)

        logger.setLevel(level)
        logger.addHandler(std_handler)
        if file_handler is not None:
            logger.addHandler(file_handler)

    def add_hook(self, hook):
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            self.__hooks += [hook()]

    def welcome_message(self):
        """
        Log the pySDC welcome banner at INFO level.
        """
        banner_lines = (
            "Welcome to the one and only, really very astonishing and 87.3% bug free",
            r" _____ _____ _____ ",
            r" / ____| __ \ / ____|",
            r" _ __ _ _| (___ | | | | | ",
            r" | '_ \| | | |\___ \| | | | | ",
            r" | |_) | |_| |____) | |__| | |____ ",
            r" | .__/ \__, |_____/|_____/ \_____|",
            r" | | __/ | ",
            r" |_| |___/ ",
            r" ",
        )
        self.logger.info("\n".join(banner_lines))

    def dump_setup(self, step, controller_params, description):
        """
        Helper function to dump the setup used for this controller

        Parameters that the user set explicitly are marked with '-->'. Missing `*_params` sections in the
        description are treated as empty, meaning no parameter of that section is user-defined.

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem
        """
        self.welcome_message()
        self.logger.info('Setup overview (--> user-defined, -> dependency) -- BEGIN')

        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        out += self._format_param_lines(vars(self.params), controller_params)

        out += '\nStep: %s\n' % step.__class__
        out += self._format_param_lines(vars(step.params), description.get('step_params', {}))
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            out += self._format_param_lines(vars(L.params), description.get('level_params', {}))
            out += '--> Problem: %s\n' % L.prob.__class__
            # problem parameters are listed completely, including names starting with '_'
            out += self._format_param_lines(
                L.prob.params, description.get('problem_params', {}), skip_private=False
            )
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            out += self._format_param_lines(vars(L.sweep.params), description.get('sweeper_params', {}))
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        if len(step.levels) > 1:
            user_base_transfer = (
                'base_transfer_class' in description and description['base_transfer_class'] is not BaseTransfer
            )
            if user_base_transfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            out += self._format_param_lines(
                vars(step.base_transfer.params), description.get('base_transfer_params', {})
            )
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            out += self._format_param_lines(
                vars(step.base_transfer.space_transfer.params), description.get('space_transfer_params', {})
            )

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        self.logger.info('----------------------------------------------------------------------------------------------------')
        self.logger.info('Setup overview (--> user-defined, -> dependency) -- END\n')

    @staticmethod
    def _format_param_lines(params, user_keys, skip_private=True):
        """
        Format parameters as one line each, sorted by name, marking user-defined ones with '-->'.

        Args:
            params (dict): parameter names and their values
            user_keys: container of the parameter names that were set by the user
            skip_private (bool): if set, parameters whose name starts with '_' are left out

        Returns:
            str: the formatted lines, each terminated by a newline
        """
        lines = []
        for key, value in sorted(params.items(), key=lambda item: item[0]):
            if skip_private and key.startswith('_'):
                continue
            marker = '-->' if key in user_keys else ''
            lines.append(marker + ' %s = %s\n' % (key, value))
        return ''.join(lines)

    def run(self, u0, t0, Tend):
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always, concrete controllers have to override this method
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self):
        """
        Getter for the hooks

        Returns:
            list: the hook instances attached to this controller
        """
        return self.__hooks

    def setup_convergence_controllers(self, description):
        '''
        Setup variables needed for convergence controllers, notably a list containing all of them and a list containing
        their order. All convergence controllers listed in the description are added here; subclasses add the ones
        from `base_convergence_controllers`, such as `CheckConvergence`.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers = []
        self.convergence_controller_order = []

        for conv_class, params in description.get('convergence_controllers', {}).items():
            self.add_convergence_controller(conv_class, description=description, params=params)

    def add_convergence_controller(self, convergence_controller, description, params=None, allow_double=False):
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        # copy the params so the caller's dict is not modified, and pass on the MPI flag of the controller
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # only add the convergence controller if it is new or if we explicitly want it multiple times
        if convergence_controller not in [type(me) for me in self.convergence_controllers] or allow_double:
            self.convergence_controllers.append(convergence_controller(self, params, description))

        # update ordering; a stable sort keeps controllers with equal control_order in the order they were added
        orders = [C.params.control_order for C in self.convergence_controllers]
        self.convergence_controller_order = np.argsort(orders, kind='stable')

        return None

    def get_convergence_controllers_as_table(self, description):
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and their order.
        Controllers from the description are marked '-->' and dependencies are marked ' -> '.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        '''
        user_classes = description.get('convergence_controllers', {}).keys()

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        for i in range(len(self.convergence_controllers)):
            C = self.convergence_controllers[self.convergence_controller_order[i]]

            # figure out how the convergence controller was added
            if type(C) in user_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

        return out

    def return_stats(self):
        """
        Return the merged stats from all hooks

        Hooks are merged in order, so later hooks win on duplicate keys.

        Returns:
            dict: Merged stats from all hooks
        """
        stats = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats