Coverage for pySDC/core/controller.py: 98%

198 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-04 15:08 +0000

1import logging 

2import os 

3import sys 

4import numpy as np 

5 

6from pySDC.core.base_transfer import BaseTransfer 

7from pySDC.helpers.pysdc_helper import FrozenClass 

8from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence 

9from pySDC.implementations.hooks.default_hook import DefaultHooks 

10from pySDC.implementations.hooks.log_timings import CPUTimings 

11 

12 

# Parameter container for the controller: every entry of the parameter dict becomes an attribute.
class _Pars(FrozenClass):
    """
    Attribute-style view on the controller parameters.

    Built-in defaults are assigned first, then every key/value pair of ``params`` is assigned on top of them
    (overriding a default or adding a new attribute). Finally ``FrozenClass._freeze()`` is called
    (presumably to prevent adding new attributes afterwards -- see ``FrozenClass``).

    Args:
        params (dict): user-supplied controller parameters
    """

    def __init__(self, params):
        defaults = {
            'mssdc_jac': True,
            'predict_type': None,
            'all_to_done': False,
            'logger_level': 20,
            'log_to_file': False,
            'dump_setup': True,
            'fname': f'run_pid{os.getpid()}.log',
            'use_iteration_estimator': False,
        }

        # defaults go first so that user-supplied entries take precedence
        for name, value in [*defaults.items(), *params.items()]:
            setattr(self, name, value)

        self._freeze()

29 

30 

31class Controller(object): 

32 """ 

33 Base abstract controller class 

34 """ 

35 

36 def __init__(self, controller_params, description, useMPI=None): 

37 """ 

38 Initialization routine for the base controller 

39 

40 Args: 

41 controller_params (dict): parameter set for the controller and the steps 

42 """ 

43 self.useMPI = useMPI 

44 self.description = description 

45 

46 # check if we have a hook on this list. If not, use default class. 

47 self.__hooks = [] 

48 hook_classes = [DefaultHooks, CPUTimings] 

49 user_hooks = controller_params.get('hook_class', []) 

50 hook_classes += user_hooks if type(user_hooks) == list else [user_hooks] 

51 [self.add_hook(hook) for hook in hook_classes] 

52 controller_params['hook_class'] = hook_classes 

53 

54 for hook in self.hooks: 

55 hook.pre_setup(step=None, level_number=None) 

56 

57 self.params = _Pars(controller_params) 

58 

59 self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname) 

60 self.logger = logging.getLogger('controller') 

61 

62 if self.params.use_iteration_estimator and self.params.all_to_done: 

63 self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done') 

64 

65 self.base_convergence_controllers = [CheckConvergence] 

66 self.setup_convergence_controllers(description) 

67 

68 @staticmethod 

69 def __setup_custom_logger(level=None, log_to_file=None, fname=None): 

70 """ 

71 Helper function to set main parameters for the logging facility 

72 

73 Args: 

74 level (int): level of logging 

75 log_to_file (bool): flag to turn on/off logging to file 

76 fname (str): 

77 """ 

78 

79 assert type(level) is int 

80 

81 # specify formats and handlers 

82 if log_to_file: 

83 file_formatter = logging.Formatter( 

84 fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s' 

85 ) 

86 if os.path.isfile(fname): 

87 file_handler = logging.FileHandler(fname, mode='a') 

88 else: 

89 file_handler = logging.FileHandler(fname, mode='w') 

90 file_handler.setFormatter(file_formatter) 

91 else: 

92 file_handler = None 

93 

94 std_formatter = logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s') 

95 

96 if level <= logging.DEBUG: 

97 import warnings 

98 

99 warnings.warn('Running with debug output will degrade performance as all output is immediately flushed.') 

100 

101 class StreamFlushingHandler(logging.StreamHandler): 

102 """ 

103 This will immediately flush any messages to the output. 

104 """ 

105 

106 def emit(self, record): 

107 super().emit(record) 

108 self.flush() 

109 

110 std_handler = StreamFlushingHandler(sys.stdout) 

111 else: 

112 std_handler = logging.StreamHandler(sys.stdout) 

113 

114 std_handler.setFormatter(std_formatter) 

115 

116 # instantiate logger 

117 logger = logging.getLogger('') 

118 

119 # remove handlers from previous calls to controller 

120 for handler in logger.handlers[:]: 

121 logger.removeHandler(handler) 

122 

123 logger.setLevel(level) 

124 logger.addHandler(std_handler) 

125 if log_to_file: 

126 logger.addHandler(file_handler) 

127 else: 

128 pass 

129 

130 def add_hook(self, hook): 

131 """ 

132 Add a hook to the controller which will be called in addition to all other hooks whenever something happens. 

133 The hook is only added if a hook of the same class is not already present. 

134 

135 Args: 

136 hook (pySDC.Hook): A hook class that is derived from the core hook class 

137 

138 Returns: 

139 None 

140 """ 

141 if hook not in [type(me) for me in self.hooks]: 

142 self.__hooks += [hook()] 

143 

144 def welcome_message(self): 

145 out = ( 

146 "Welcome to the one and only, really very astonishing and 87.3% bug free" 

147 + "\n" 

148 + r" _____ _____ _____ " 

149 + "\n" 

150 + r" / ____| __ \ / ____|" 

151 + "\n" 

152 + r" _ __ _ _| (___ | | | | | " 

153 + "\n" 

154 + r" | '_ \| | | |\___ \| | | | | " 

155 + "\n" 

156 + r" | |_) | |_| |____) | |__| | |____ " 

157 + "\n" 

158 + r" | .__/ \__, |_____/|_____/ \_____|" 

159 + "\n" 

160 + r" | | __/ | " 

161 + "\n" 

162 + r" |_| |___/ " 

163 + "\n" 

164 + r" " 

165 ) 

166 self.logger.info(out) 

167 

168 def dump_setup(self, step, controller_params, description): 

169 """ 

170 Helper function to dump the setup used for this controller 

171 

172 Args: 

173 step (pySDC.Step.step): the step instance (will/should be the first one only) 

174 controller_params (dict): controller parameters 

175 description (dict): description of the problem 

176 """ 

177 

178 self.welcome_message() 

179 out = 'Setup overview (--> user-defined, -> dependency) -- BEGIN' 

180 self.logger.info(out) 

181 out = '----------------------------------------------------------------------------------------------------\n\n' 

182 out += 'Controller: %s\n' % self.__class__ 

183 for k, v in sorted(vars(self.params).items()): 

184 if not k.startswith('_'): 

185 if k in controller_params: 

186 out += '--> %s = %s\n' % (k, v) 

187 else: 

188 out += ' %s = %s\n' % (k, v) 

189 

190 out += '\nStep: %s\n' % step.__class__ 

191 for k, v in sorted(vars(step.params).items()): 

192 if not k.startswith('_'): 

193 if k in description['step_params']: 

194 out += '--> %s = %s\n' % (k, v) 

195 else: 

196 out += ' %s = %s\n' % (k, v) 

197 out += f' Number of steps: {step.status.time_size}\n' 

198 

199 out += ' Level: %s\n' % step.levels[0].__class__ 

200 for L in step.levels: 

201 out += ' Level %2i\n' % L.level_index 

202 for k, v in sorted(vars(L.params).items()): 

203 if not k.startswith('_'): 

204 if k in description['level_params']: 

205 out += '--> %s = %s\n' % (k, v) 

206 else: 

207 out += ' %s = %s\n' % (k, v) 

208 out += '--> Problem: %s\n' % L.prob.__class__ 

209 for k, v in sorted(L.prob.params.items()): 

210 if k in description['problem_params']: 

211 out += '--> %s = %s\n' % (k, v) 

212 else: 

213 out += ' %s = %s\n' % (k, v) 

214 out += '--> Data type u: %s\n' % L.prob.dtype_u 

215 out += '--> Data type f: %s\n' % L.prob.dtype_f 

216 out += '--> Sweeper: %s\n' % L.sweep.__class__ 

217 for k, v in sorted(vars(L.sweep.params).items()): 

218 if not k.startswith('_'): 

219 if k in description['sweeper_params']: 

220 out += '--> %s = %s\n' % (k, v) 

221 else: 

222 out += ' %s = %s\n' % (k, v) 

223 out += '--> Collocation: %s\n' % L.sweep.coll.__class__ 

224 

225 if len(step.levels) > 1: 

226 if 'base_transfer_class' in description and description['base_transfer_class'] is not BaseTransfer: 

227 out += '--> Base Transfer: %s\n' % step.base_transfer.__class__ 

228 else: 

229 out += ' Base Transfer: %s\n' % step.base_transfer.__class__ 

230 for k, v in sorted(vars(step.base_transfer.params).items()): 

231 if not k.startswith('_'): 

232 if k in description['base_transfer_params']: 

233 out += '--> %s = %s\n' % (k, v) 

234 else: 

235 out += ' %s = %s\n' % (k, v) 

236 out += '--> Space Transfer: %s\n' % step.base_transfer.space_transfer.__class__ 

237 for k, v in sorted(vars(step.base_transfer.space_transfer.params).items()): 

238 if not k.startswith('_'): 

239 if k in description['space_transfer_params']: 

240 out += '--> %s = %s\n' % (k, v) 

241 else: 

242 out += ' %s = %s\n' % (k, v) 

243 

244 out += '\n' 

245 out += self.get_convergence_controllers_as_table(description) 

246 out += '\n' 

247 self.logger.info(out) 

248 

249 out = '----------------------------------------------------------------------------------------------------' 

250 self.logger.info(out) 

251 out = 'Setup overview (--> user-defined, -> dependency) -- END\n' 

252 self.logger.info(out) 

253 

254 def run(self, u0, t0, Tend): 

255 """ 

256 Abstract interface to the run() method 

257 

258 Args: 

259 u0: initial values 

260 t0 (float): starting time 

261 Tend (float): ending time 

262 """ 

263 raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)') 

264 

265 @property 

266 def hooks(self): 

267 """ 

268 Getter for the hooks 

269 

270 Returns: 

271 pySDC.Hooks.hooks: hooks 

272 """ 

273 return self.__hooks 

274 

275 def setup_convergence_controllers(self, description): 

276 ''' 

277 Setup variables needed for convergence controllers, notably a list containing all of them and a list containing 

278 their order. Also, we add the `CheckConvergence` convergence controller, which takes care of maximum iteration 

279 count or a residual based stopping criterion, as well as all convergence controllers added to the description. 

280 

281 Args: 

282 description (dict): The description object used to instantiate the controller 

283 

284 Returns: 

285 None 

286 ''' 

287 self.convergence_controllers = [] 

288 self.convergence_controller_order = [] 

289 conv_classes = description.get('convergence_controllers', {}) 

290 

291 # instantiate the convergence controllers 

292 for conv_class, params in conv_classes.items(): 

293 self.add_convergence_controller(conv_class, description=description, params=params) 

294 

295 return None 

296 

297 def add_convergence_controller(self, convergence_controller, description, params=None, allow_double=False): 

298 ''' 

299 Add an individual convergence controller to the list of convergence controllers and instantiate it. 

300 Afterwards, the order of the convergence controllers is updated. 

301 

302 Args: 

303 convergence_controller (pySDC.ConvergenceController): The convergence controller to be added 

304 description (dict): The description object used to instantiate the controller 

305 params (dict): Parameters for the convergence controller 

306 allow_double (bool): Allow adding the same convergence controller multiple times 

307 

308 Returns: 

309 None 

310 ''' 

311 # check if we passed any sort of special params 

312 params = {**({} if params is None else params), 'useMPI': self.useMPI} 

313 

314 # check if we already have the convergence controller or if we want to have it multiple times 

315 if convergence_controller not in [type(me) for me in self.convergence_controllers] or allow_double: 

316 self.convergence_controllers.append(convergence_controller(self, params, description)) 

317 

318 # update ordering 

319 orders = [C.params.control_order for C in self.convergence_controllers] 

320 self.convergence_controller_order = np.arange(len(self.convergence_controllers))[np.argsort(orders)] 

321 

322 return None 

323 

324 def get_convergence_controllers_as_table(self, description): 

325 ''' 

326 This function is for debugging purposes to keep track of the different convergence controllers and their order. 

327 

328 Args: 

329 description (dict): Description of the problem 

330 

331 Returns: 

332 str: Table of convergence controllers as a string 

333 ''' 

334 out = 'Active convergence controllers:' 

335 out += '\n | # | order | convergence controller' 

336 out += '\n----+----+-------+---------------------------------------------------------------------------------------' 

337 for i in range(len(self.convergence_controllers)): 

338 C = self.convergence_controllers[self.convergence_controller_order[i]] 

339 

340 # figure out how the convergence controller was added 

341 if type(C) in description.get('convergence_controllers', {}).keys(): # added by user 

342 user_added = '--> ' 

343 elif type(C) in self.base_convergence_controllers: # added by default 

344 user_added = ' ' 

345 else: # added as dependency 

346 user_added = ' -> ' 

347 

348 out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}' 

349 

350 return out 

351 

352 def return_stats(self): 

353 """ 

354 Return the merged stats from all hooks 

355 

356 Returns: 

357 dict: Merged stats from all hooks 

358 """ 

359 stats = {} 

360 for hook in self.hooks: 

361 stats = {**stats, **hook.return_stats()} 

362 return stats 

363 

364 

class ParaDiagController(Controller):
    """
    Base class for ParaDiag controllers.

    Forces ``all_to_done=True``, requires an ``alpha`` controller parameter and provides weighted (i)FFTs in time.
    The actual application of the time matrices is delegated to ``self.apply_matrix``, which is not defined here.
    NOTE(review): presumably implemented by concrete subclasses -- confirm.
    """

    def __init__(self, controller_params, description, n_steps, useMPI=None):
        """
        Initialization routine for ParaDiag controllers

        Args:
            controller_params (dict): parameter set for the controller and the steps. Must contain ``alpha`` (the
                alpha parameter for ParaDiag). ``all_to_done`` may be omitted or True; it is set to True in place,
                and ``average_jacobian`` defaults to True.
            description (dict): all the parameters to set up the rest (levels, problems, transfer, ...). If the
                sweeper derives from QDiagonalization, ``ignore_ic`` and ``update_f_evals`` are set in its
                ``sweeper_params`` in place.
            n_steps (int): Number of parallel steps
            useMPI (bool): passed on to the base controller

        Raises:
            NotImplementedError: if ``all_to_done`` is explicitly set to a falsy value
            ParameterError: if ``alpha`` is missing from ``controller_params``
        """
        from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization

        if QDiagonalization in description['sweeper_class'].__mro__:
            description['sweeper_params']['ignore_ic'] = True
            description['sweeper_params']['update_f_evals'] = False
        else:
            logging.getLogger('controller').warning(
                f'Warning: Your sweeper class {description["sweeper_class"]} is not derived from {QDiagonalization}. You probably want to use another sweeper class.'
            )

        # Only all_to_done=True is supported: reject an explicit False, accept True or a missing entry.
        # (Accepting True is required anyway, since it is written back into controller_params below.)
        if not controller_params.get('all_to_done', True):
            raise NotImplementedError('ParaDiag only implemented with option `all_to_done=True`')
        if 'alpha' not in controller_params:
            from pySDC.core.errors import ParameterError

            raise ParameterError('Please supply alpha as a parameter to the ParaDiag controller!')
        controller_params['average_jacobian'] = controller_params.get('average_jacobian', True)

        controller_params['all_to_done'] = True
        super().__init__(controller_params=controller_params, description=description, useMPI=useMPI)

        self.n_steps = n_steps

        # weighted (i)FFT matrices, built lazily on first use.
        # Assumes n_steps and params.alpha do not change afterwards.
        self.__FFT_matrix = None
        self.__iFFT_matrix = None

    def FFT_in_time(self, quantity):
        """
        Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        Note: The implementation via matrix-vector multiplication may be inefficient and less stable compared to an FFT
        with transposes!

        Args:
            quantity: the data to transform, handed to ``self.apply_matrix``
        """
        if self.__FFT_matrix is None:
            from pySDC.helpers.ParaDiagHelper import get_weighted_FFT_matrix

            self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha)

        self.apply_matrix(self.__FFT_matrix, quantity)

    def iFFT_in_time(self, quantity):
        """
        Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        Args:
            quantity: the data to transform, handed to ``self.apply_matrix``
        """
        if self.__iFFT_matrix is None:
            from pySDC.helpers.ParaDiagHelper import get_weighted_iFFT_matrix

            self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha)

        self.apply_matrix(self.__iFFT_matrix, quantity)