Coverage for pySDC/core/controller.py: 99%

164 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-20 17:10 +0000

1import logging 

2import os 

3import sys 

4import numpy as np 

5 

6from pySDC.core.base_transfer import BaseTransfer 

7from pySDC.helpers.pysdc_helper import FrozenClass 

8from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence 

9from pySDC.implementations.hooks.default_hook import DefaultHooks 

10 

11 

# short helper class to add params as attributes
class _Pars(FrozenClass):
    """
    Parameter container of the controller.

    Every default listed below becomes an attribute first, then every entry of the user-supplied dict is set
    (overriding defaults of the same name or adding new names), and finally the instance is frozen via
    FrozenClass (presumably rejecting unknown attribute names afterwards - see pySDC.helpers.pysdc_helper).
    """

    def __init__(self, params):
        defaults = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            'fname': 'run_pid' + str(os.getpid()) + '.log',
            'use_iteration_estimator': False,
        }

        # defaults go first so that user-supplied values overwrite them
        for name, value in defaults.items():
            setattr(self, name, value)
        for name, value in params.items():
            setattr(self, name, value)

        self._freeze()

28 

29 

class Controller(object):
    """
    Base abstract controller class

    Collects the hooks, the controller parameters, the logging setup and the convergence controllers shared by
    all concrete controllers. Subclasses have to implement `run`.
    """

    def __init__(self, controller_params, description, useMPI=None):
        """
        Initialization routine for the base controller

        Args:
            controller_params (dict): parameter set for the controller and the steps. Its 'hook_class' entry may
                be a single hook class or a list/tuple of hook classes. It is overwritten in place with the
                complete list of hook classes, starting with DefaultHooks.
            description (dict): description of the problem; convergence controllers listed under its
                'convergence_controllers' entry are instantiated
            useMPI: forwarded to every convergence controller as its 'useMPI' parameter
        """
        self.useMPI = useMPI

        # DefaultHooks are always active; user hooks may be a single class or a list/tuple of classes
        self.__hooks = []
        hook_classes = [DefaultHooks]
        user_hooks = controller_params.get('hook_class', [])
        if isinstance(user_hooks, (list, tuple)):
            hook_classes.extend(user_hooks)
        else:
            hook_classes.append(user_hooks)
        for hook in hook_classes:
            self.add_hook(hook)
        # write the complete list back so that the controller parameters reflect all requested hooks
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(level=None, log_to_file=None, fname=None):
        """
        Helper function to set main parameters for the logging facility

        Configures the root logger: all handlers currently attached to it are detached, the level is set and a
        stdout handler (plus, optionally, a file handler) is attached. Handlers that were created by an earlier
        call of this function are closed when detached, so their files are released immediately; handlers
        installed by anybody else are only detached.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file, only used if log_to_file is set; an existing file is appended to
        """

        assert type(level) is int

        # attribute marking handlers created here, so later calls know they may close them
        owned_marker = '_pysdc_controller_handler'

        # specify formats and handlers
        file_handler = None
        if log_to_file:
            file_formatter = logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
            )
            # append mode appends to an existing file and creates a missing one
            file_handler = logging.FileHandler(fname, mode='a')
            file_handler.setFormatter(file_formatter)
            setattr(file_handler, owned_marker, True)

        std_formatter = logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s')
        std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(std_formatter)
        setattr(std_handler, owned_marker, True)

        # instantiate logger
        logger = logging.getLogger('')

        # remove handlers from previous calls to controller; close our own ones to release their files
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            if getattr(handler, owned_marker, False):
                handler.close()

        logger.setLevel(level)
        logger.addHandler(std_handler)
        if file_handler is not None:
            logger.addHandler(file_handler)

    def add_hook(self, hook):
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class that is derived from the core hook class

        Returns:
            None
        """
        if hook not in [type(me) for me in self.hooks]:
            self.__hooks.append(hook())

    def welcome_message(self):
        """
        Log the pySDC welcome banner at info level
        """
        out = (
            "Welcome to the one and only, really very astonishing and 87.3% bug free"
            + "\n"
            + r" _____ _____ _____ "
            + "\n"
            + r" / ____| __ \ / ____|"
            + "\n"
            + r" _ __ _ _| (___ | | | | | "
            + "\n"
            + r" | '_ \| | | |\___ \| | | | | "
            + "\n"
            + r" | |_) | |_| |____) | |__| | |____ "
            + "\n"
            + r" | .__/ \__, |_____/|_____/ \_____|"
            + "\n"
            + r" | | __/ | "
            + "\n"
            + r" |_| |___/ "
            + "\n"
            + r" "
        )
        self.logger.info(out)

    @staticmethod
    def _format_params(items, user_defined, skip_private=True):
        """
        Render parameters as lines of the setup overview, sorted by name

        Args:
            items: iterable of (name, value) pairs
            user_defined: container of the names set by the user; these lines are marked with '-->'
            skip_private (bool): drop names starting with an underscore

        Returns:
            str: one line per rendered parameter
        """
        out = ''
        for k, v in sorted(items):
            if skip_private and k.startswith('_'):
                continue
            if k in user_defined:
                out += '--> %s = %s\n' % (k, v)
            else:
                out += ' %s = %s\n' % (k, v)
        return out

    def dump_setup(self, step, controller_params, description):
        """
        Helper function to dump the setup used for this controller

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem; missing '*_params' entries count as empty
        """

        self.welcome_message()
        out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN'
        self.logger.info(out)
        out = '----------------------------------------------------------------------------------------------------\n\n'
        out += 'Controller: %s\n' % self.__class__
        out += self._format_params(vars(self.params).items(), controller_params)

        out += '\nStep: %s\n' % step.__class__
        out += self._format_params(vars(step.params).items(), description.get('step_params', {}))
        out += f' Number of steps: {step.status.time_size}\n'

        out += ' Level: %s\n' % step.levels[0].__class__
        for L in step.levels:
            out += ' Level %2i\n' % L.level_index
            out += self._format_params(vars(L.params).items(), description.get('level_params', {}))
            out += '--> Problem: %s\n' % L.prob.__class__
            # problem parameters are listed completely, including names starting with an underscore
            out += self._format_params(
                L.prob.params.items(), description.get('problem_params', {}), skip_private=False
            )
            out += '--> Data type u: %s\n' % L.prob.dtype_u
            out += '--> Data type f: %s\n' % L.prob.dtype_f
            out += '--> Sweeper: %s\n' % L.sweep.__class__
            out += self._format_params(vars(L.sweep.params).items(), description.get('sweeper_params', {}))
            out += '--> Collocation: %s\n' % L.sweep.coll.__class__

        # transfer operators only exist between levels
        if len(step.levels) > 1:
            if 'base_transfer_class' in description and description['base_transfer_class'] is not BaseTransfer:
                out += '--> Base Transfer: %s\n' % step.base_transfer.__class__
            else:
                out += ' Base Transfer: %s\n' % step.base_transfer.__class__
            out += self._format_params(
                vars(step.base_transfer.params).items(), description.get('base_transfer_params', {})
            )
            out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__
            out += self._format_params(
                vars(step.base_transfer.space_transfer.params).items(), description.get('space_transfer_params', {})
            )

        out += '\n'
        out += self.get_convergence_controllers_as_table(description)
        out += '\n'
        self.logger.info(out)

        out = '----------------------------------------------------------------------------------------------------'
        self.logger.info(out)
        out = 'Setup overview (--> user-defined, -> dependency) -- END\n'
        self.logger.info(out)

    def run(self, u0, t0, Tend):
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always, subclasses have to provide the implementation
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self):
        """
        Getter for the hooks

        Returns:
            pySDC.Hooks.hooks: hooks
        """
        return self.__hooks

    def setup_convergence_controllers(self, description):
        '''
        Setup variables needed for convergence controllers, notably a list containing all of them and a list containing
        their order. Every convergence controller listed in the description is added. The `CheckConvergence`
        convergence controller, which takes care of the maximum iteration count or a residual based stopping
        criterion, is recorded in `self.base_convergence_controllers` by `__init__`.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers = []
        self.convergence_controller_order = []
        conv_classes = description.get('convergence_controllers', {})

        # instantiate the convergence controllers
        for conv_class, params in conv_classes.items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(self, convergence_controller, description, params=None, allow_double=False):
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller; copied, the caller's dict is not modified
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        # copy the params and add the MPI flag of the controller
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        if allow_double or convergence_controller not in [type(me) for me in self.convergence_controllers]:
            self.convergence_controllers.append(convergence_controller(self, params, description))

        # update ordering; a stable sort keeps controllers with equal control_order in insertion order
        orders = [C.params.control_order for C in self.convergence_controllers]
        self.convergence_controller_order = np.argsort(orders, kind='stable')

        return None

    def get_convergence_controllers_as_table(self, description):
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string, sorted by their control_order
        '''
        user_added_classes = description.get('convergence_controllers', {})

        out = 'Active convergence controllers:'
        out += '\n | # | order | convergence controller'
        out += '\n----+----+-------+---------------------------------------------------------------------------------------'
        for i, index in enumerate(self.convergence_controller_order):
            C = self.convergence_controllers[index]

            # figure out how the convergence controller was added
            if type(C) in user_added_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

        return out

    def return_stats(self):
        """
        Return the merged stats from all hooks

        Returns:
            dict: Merged stats from all hooks; for duplicate keys the hook added last wins
        """
        stats = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats