Coverage for pySDC/core/controller.py: 99%

165 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import logging 

2import os 

3import sys 

4import numpy as np 

5 

6from pySDC.core.base_transfer import BaseTransfer 

7from pySDC.helpers.pysdc_helper import FrozenClass 

8from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence 

9from pySDC.implementations.hooks.default_hook import DefaultHooks 

10from pySDC.implementations.hooks.log_timings import CPUTimings 

11 

12 

# Small helper class that exposes controller parameters as attributes.
class _Pars(FrozenClass):
    """
    Parameter container for the controller.

    A fixed set of defaults is assigned first. Afterwards every entry of ``params`` is set as an attribute,
    overriding a default of the same name or adding a new attribute. Finally ``self._freeze()`` (from
    FrozenClass) is called; presumably this prevents adding further attributes later.
    """

    def __init__(self, params):
        default_values = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            'fname': 'run_pid' + str(os.getpid()) + '.log',
            'use_iteration_estimator': False,
        }
        for attr_name, attr_value in default_values.items():
            setattr(self, attr_name, attr_value)

        # user-supplied entries take precedence over the defaults above
        for attr_name, attr_value in params.items():
            setattr(self, attr_name, attr_value)

        self._freeze()

29 

30 

class Controller(object):
    """
    Base abstract controller class
    """

    def __init__(self, controller_params, description, useMPI=None):
        """
        Initialization routine for the base controller

        Args:
            controller_params (dict): parameter set for the controller and the steps. Note that its entry
                'hook_class' is overwritten in place with the complete list of hook classes in use.
            description (dict): description of the problem; its optional 'convergence_controllers' entry
                (mapping convergence controller class -> params) is used to set up the convergence controllers
            useMPI: value forwarded to every convergence controller as its 'useMPI' parameter
        """
        self.useMPI = useMPI

        # DefaultHooks and CPUTimings are always used; user hooks may be a single class or a list/tuple of classes
        self.__hooks = []
        hook_classes = [DefaultHooks, CPUTimings]
        user_hooks = controller_params.get('hook_class', [])
        if isinstance(user_hooks, (list, tuple)):
            hook_classes += list(user_hooks)
        else:
            hook_classes.append(user_hooks)
        for hook in hook_classes:
            self.add_hook(hook)
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(level=None, log_to_file=None, fname=None):
        """
        Helper function to set main parameters for the logging facility

        Configures the root logger. All handlers currently attached to it are detached, the level is set, and
        a handler writing to stdout is attached, plus a file handler if requested. Handlers that an earlier
        call of this helper installed are also closed, so repeated controller setups do not leak file handles.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file; appended to if it exists, created otherwise
        """

        assert type(level) is int, 'logger_level has to be an int'

        # marker attribute identifying handlers that were installed by this helper
        owned_marker = '_pySDC_controller_handler'

        # specify formats and handlers
        file_handler = None
        if log_to_file:
            file_formatter = logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
            )
            mode = 'a' if os.path.isfile(fname) else 'w'
            file_handler = logging.FileHandler(fname, mode=mode)
            file_handler.setFormatter(file_formatter)
            setattr(file_handler, owned_marker, True)

        std_formatter = logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s')
        std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(std_formatter)
        setattr(std_handler, owned_marker, True)

        # instantiate logger
        logger = logging.getLogger('')

        # Remove handlers from previous calls to controller. Our own handlers are closed too, otherwise each
        # controller logging to file would leave its file open. Foreign handlers (e.g. installed by a test
        # framework) are only detached, since closing them could break their owner.
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            if getattr(handler, owned_marker, False):
                handler.close()

        logger.setLevel(level)
        logger.addHandler(std_handler)
        if file_handler is not None:
            logger.addHandler(file_handler)

    def add_hook(self, hook):
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            self.__hooks += [hook()]

    def welcome_message(self):
        """
        Log the pySDC welcome banner at info level.
        """
        out = (
            "Welcome to the one and only, really very astonishing and 87.3% bug free"
            + "\n"
            + r" _____ _____ _____ "
            + "\n"
            + r" / ____| __ \ / ____|"
            + "\n"
            + r" _ __ _ _| (___ | | | | | "
            + "\n"
            + r" | '_ \| | | |\___ \| | | | | "
            + "\n"
            + r" | |_) | |_| |____) | |__| | |____ "
            + "\n"
            + r" | .__/ \__, |_____/|_____/ \_____|"
            + "\n"
            + r" | | __/ | "
            + "\n"
            + r" |_| |___/ "
            + "\n"
            + r" "
        )
        self.logger.info(out)

    @staticmethod
    def _format_param_lines(items, user_keys, skip_private=True):
        """
        Format key/value pairs for the setup overview, one line per pair, sorted by key.

        Args:
            items: iterable of (key, value) pairs with unique string keys
            user_keys: container of keys set by the user; their lines are marked with '-->'
            skip_private (bool): if True, keys starting with an underscore are left out

        Returns:
            str: the formatted lines
        """
        out = ''
        for k, v in sorted(items):
            if skip_private and k.startswith('_'):
                continue
            if k in user_keys:
                out += '--> %s = %s\n' % (k, v)
            else:
                out += ' %s = %s\n' % (k, v)
        return out

    def dump_setup(self, step, controller_params, description):
        """
        Helper function to dump the setup used for this controller

        Parameter sections missing from ``description`` are treated as empty, i.e. all parameters of that
        section are reported as not user-defined.

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem
        """

        self.welcome_message()
        out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN'
        self.logger.info(out)
        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        out += self._format_param_lines(vars(self.params).items(), controller_params)

        out += '\nStep: %s\n' % step.__class__
        out += self._format_param_lines(vars(step.params).items(), description.get('step_params', {}))
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            out += self._format_param_lines(vars(L.params).items(), description.get('level_params', {}))
            out += '--> Problem: %s\n' % L.prob.__class__
            # problem parameters are listed completely, including names starting with an underscore
            out += self._format_param_lines(
                L.prob.params.items(), description.get('problem_params', {}), skip_private=False
            )
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            out += self._format_param_lines(vars(L.sweep.params).items(), description.get('sweeper_params', {}))
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        # transfer operators only exist in multi-level hierarchies
        if len(step.levels) > 1:
            if description.get('base_transfer_class', BaseTransfer) is not BaseTransfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            out += self._format_param_lines(
                vars(step.base_transfer.params).items(), description.get('base_transfer_params', {})
            )
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            out += self._format_param_lines(
                vars(step.base_transfer.space_transfer.params).items(), description.get('space_transfer_params', {})
            )

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        out = '----------------------------------------------------------------------------------------------------'
        self.logger.info(out)
        out = 'Setup overview (--> user-defined, -> dependency) -- END\n'
        self.logger.info(out)

    def run(self, u0, t0, Tend):
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self):
        """
        Getter for the hooks

        Returns:
            pySDC.Hooks.hooks: hooks
        """
        return self.__hooks

    def setup_convergence_controllers(self, description):
        '''
        Setup variables needed for convergence controllers, notably a list containing all of them and a list containing
        their order. All convergence controllers listed in the description's 'convergence_controllers' entry are added.
        Note that `CheckConvergence` is only recorded in `base_convergence_controllers` by `__init__`; it is
        presumably added by subclasses or dependencies -- this method does not add it itself.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers = []
        self.convergence_controller_order = []
        conv_classes = description.get('convergence_controllers', {})

        # instantiate the convergence controllers
        for conv_class, params in conv_classes.items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(self, convergence_controller, description, params=None, allow_double=False):
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated: they are sorted by
        `params.control_order`, with ties kept in insertion order.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        # copy the user params (never mutate them) and inject the controller's useMPI setting
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        if convergence_controller not in [type(me) for me in self.convergence_controllers] or allow_double:
            self.convergence_controllers.append(convergence_controller(self, params, description))

        # update ordering; a stable sort keeps controllers with equal control_order in insertion order
        orders = [C.params.control_order for C in self.convergence_controllers]
        self.convergence_controller_order = np.argsort(orders, kind='stable')

        return None

    def get_convergence_controllers_as_table(self, description):
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        '''
        user_classes = description.get('convergence_controllers', {})

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        for i, idx in enumerate(self.convergence_controller_order):
            C = self.convergence_controllers[idx]

            # figure out how the convergence controller was added
            if type(C) in user_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

        return out

    def return_stats(self):
        """
        Return the merged stats from all hooks

        Returns:
            dict: Merged stats from all hooks; for duplicate keys, later hooks win
        """
        stats = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats