Coverage for pySDC / core / controller.py: 98%

199 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-20 16:04 +0000

1import logging 

2import os 

3import sys 

4from typing import Any, Dict, List, Optional, Type, Union 

5import numpy as np 

6 

7from pySDC.core.base_transfer import BaseTransfer 

8from pySDC.helpers.pysdc_helper import FrozenClass 

9from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence 

10from pySDC.implementations.hooks.default_hook import DefaultHooks 

11from pySDC.implementations.hooks.log_timings import CPUTimings 

12 

13 

14# short helper class to add params as attributes 

class _Pars(FrozenClass):
    """
    Container exposing the controller parameters as attributes

    The defaults are assigned first, afterwards every entry of the user-supplied dictionary is set as an attribute,
    overriding a default of the same name or adding a new attribute. Finally `_freeze` (inherited from
    `FrozenClass`) is called; presumably this locks the set of attributes -- confirm in `FrozenClass`.
    """

    def __init__(self, params: Dict[str, Any]) -> None:
        """
        Initialization routine

        Args:
            params (dict): user-supplied controller parameters
        """
        # default values
        self.mssdc_jac: bool = True
        self.predict_type: Optional[str] = None
        self.all_to_done: bool = False
        self.logger_level: int = 20
        self.log_to_file: bool = False
        self.dump_setup: bool = True
        # the process id is part of the default log file name
        self.fname: str = f'run_pid{os.getpid()}.log'
        self.use_iteration_estimator: bool = False

        # user-supplied values take precedence over the defaults
        for name in params:
            setattr(self, name, params[name])

        self._freeze()

30 

31 

class Controller(object):
    """
    Base abstract controller class

    Bundles what the concrete controllers have in common: instantiating the hooks, storing the controller
    parameters, configuring the logging facility and managing the list of convergence controllers. Derived classes
    have to implement `run`.
    """

    def __init__(
        self, controller_params: Dict[str, Any], description: Dict[str, Any], useMPI: Optional[bool] = None
    ) -> None:
        """
        Initialization routine for the base controller

        Note that `controller_params['hook_class']` is overwritten in place with the complete list of hook classes,
        i.e. the default hooks followed by the user-supplied ones.

        Args:
            controller_params (dict): parameter set for the controller and the steps
            description (dict): description of the problem, forwarded to the convergence controllers
            useMPI (bool): flag that is stored and handed to every convergence controller as parameter
        """
        self.useMPI: Optional[bool] = useMPI
        self.description: Dict[str, Any] = description

        # the default hooks always come first, user-supplied hooks are appended
        self.__hooks: List[Any] = []
        hook_classes: List[Type[Any]] = [DefaultHooks, CPUTimings]
        user_hooks = controller_params.get('hook_class', [])
        # accept a single hook class as well as a list or tuple of hook classes
        hook_classes += list(user_hooks) if isinstance(user_hooks, (list, tuple)) else [user_hooks]
        for hook in hook_classes:
            self.add_hook(hook)
        controller_params['hook_class'] = hook_classes

        for hook in self.hooks:
            hook.pre_setup(step=None, level_number=None)

        self.params: _Pars = _Pars(controller_params)

        self.__setup_custom_logger(self.params.logger_level, self.params.log_to_file, self.params.fname)
        self.logger: logging.Logger = logging.getLogger('controller')

        if self.params.use_iteration_estimator and self.params.all_to_done:
            self.logger.warning('all_to_done and use_iteration_estimator set, will ignore all_to_done')

        self.base_convergence_controllers: List[Type[Any]] = [CheckConvergence]
        self.setup_convergence_controllers(description)

    @staticmethod
    def __setup_custom_logger(
        level: Optional[int] = None, log_to_file: Optional[bool] = None, fname: Optional[str] = None
    ) -> None:
        """
        Helper function to set main parameters for the logging facility

        All handlers attached to the root logger are replaced by a handler writing to stdout and, if requested, a
        handler appending to a log file. Handlers installed by an earlier call of this function are closed when they
        are replaced, such that log files of previous controllers do not stay open.

        Args:
            level (int): level of logging
            log_to_file (bool): flag to turn on/off logging to file
            fname (str): name of the log file, only used if `log_to_file` is set
        """

        assert type(level) is int, 'the logging level has to be an int'

        # attribute marking the handlers created by this function
        owned_marker = '_pysdc_controller_handler'

        if level <= logging.DEBUG:
            import warnings

            warnings.warn('Running with debug output will degrade performance as all output is immediately flushed.')

            class StreamFlushingHandler(logging.StreamHandler):
                """
                This will immediately flush any messages to the output.
                """

                def emit(self, record: logging.LogRecord) -> None:
                    super().emit(record)
                    self.flush()

            std_handler = StreamFlushingHandler(sys.stdout)
        else:
            std_handler = logging.StreamHandler(sys.stdout)
        std_handler.setFormatter(logging.Formatter(fmt='%(name)s - %(levelname)s: %(message)s'))
        new_handlers: List[logging.Handler] = [std_handler]

        if log_to_file:
            # append mode creates the file if it does not exist yet, no need to check for the file beforehand
            file_handler = logging.FileHandler(fname, mode='a')
            file_handler.setFormatter(
                logging.Formatter(
                    fmt='%(asctime)s - %(name)s - %(module)s - %(funcName)s - %(lineno)d - %(levelname)s: %(message)s'
                )
            )
            new_handlers.append(file_handler)

        # instantiate logger
        root_logger = logging.getLogger('')

        # remove handlers from previous calls to controller. Only the handlers created here are closed, handlers
        # installed by somebody else are merely detached as before.
        for handler in root_logger.handlers[:]:
            root_logger.removeHandler(handler)
            if getattr(handler, owned_marker, False):
                handler.close()

        root_logger.setLevel(level)
        for handler in new_handlers:
            setattr(handler, owned_marker, True)
            root_logger.addHandler(handler)

    def add_hook(self, hook: Type[Any]) -> None:
        """
        Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
        The hook is only added if a hook of the same class is not already present.

        Args:
            hook (pySDC.Hook): A hook class (not an instance), it is instantiated here without arguments

        Returns:
            None
        """
        if hook not in (type(me) for me in self.hooks):
            self.__hooks.append(hook())

    def welcome_message(self) -> None:
        """
        Log the welcome banner with level info
        """
        banner = (
            "Welcome to the one and only, really very astonishing and 87.3% bug free",
            r" _____ _____ _____ ",
            r" / ____| __ \ / ____|",
            r" _ __ _ _| (___ | | | | | ",
            r" | '_ \| | | |\___ \| | | | | ",
            r" | |_) | |_| |____) | |__| | |____ ",
            r" | .__/ \__, |_____/|_____/ \_____|",
            r" | | __/ | ",
            r" |_| |___/ ",
            r" ",
        )
        self.logger.info("\n".join(banner))

    @staticmethod
    def __format_params(items: Any, user_defined: Any, skip_private: bool = True) -> str:
        """
        Format parameters as lines of the setup overview, marking the user-defined ones with an arrow

        Args:
            items: iterable of (name, value) pairs
            user_defined: container holding the names of the parameters that were set by the user
            skip_private (bool): leave out parameters whose name starts with an underscore

        Returns:
            str: one line per parameter, sorted by name
        """
        lines = []
        for name, value in sorted(items):
            if skip_private and name.startswith('_'):
                continue
            marker = '--> ' if name in user_defined else ' '
            lines.append('%s%s = %s\n' % (marker, name, value))
        return ''.join(lines)

    def dump_setup(self, step: Any, controller_params: Dict[str, Any], description: Dict[str, Any]) -> None:
        """
        Helper function to dump the setup used for this controller

        Parameters that appear in the user-supplied dictionaries are marked with `-->`. A section that is missing
        from the description is treated as containing no user-defined parameters.

        Args:
            step (pySDC.Step.step): the step instance (will/should be the first one only)
            controller_params (dict): controller parameters
            description (dict): description of the problem
        """

        self.welcome_message()
        self.logger.info('Setup overview (--> user-defined, -> dependency) -- BEGIN')

        parts = [
            '----------------------------------------------------------------------------------------------------\n\n'
        ]
        parts.append('Controller: %s\n' % self.__class__)
        parts.append(self.__format_params(vars(self.params).items(), controller_params))

        parts.append('\nStep: %s\n' % step.__class__)
        parts.append(self.__format_params(vars(step.params).items(), description.get('step_params', {})))
        parts.append(f' Number of steps: {step.status.time_size}\n')

        parts.append(' Level: %s\n' % step.levels[0].__class__)
        for L in step.levels:
            parts.append(' Level %2i\n' % L.level_index)
            parts.append(self.__format_params(vars(L.params).items(), description.get('level_params', {})))
            parts.append('--> Problem: %s\n' % L.prob.__class__)
            # the problem parameters are listed completely, including names with a leading underscore
            parts.append(
                self.__format_params(L.prob.params.items(), description.get('problem_params', {}), skip_private=False)
            )
            parts.append('--> Data type u: %s\n' % L.prob.dtype_u)
            parts.append('--> Data type f: %s\n' % L.prob.dtype_f)
            parts.append('--> Sweeper: %s\n' % L.sweep.__class__)
            parts.append(self.__format_params(vars(L.sweep.params).items(), description.get('sweeper_params', {})))
            parts.append('--> Collocation: %s\n' % L.sweep.coll.__class__)

        # transfer operators are only reported for multi-level setups
        if len(step.levels) > 1:
            transfer = step.base_transfer
            if description.get('base_transfer_class', BaseTransfer) is not BaseTransfer:
                parts.append('--> Base Transfer: %s\n' % transfer.__class__)
            else:
                parts.append(' Base Transfer: %s\n' % transfer.__class__)
            parts.append(
                self.__format_params(vars(transfer.params).items(), description.get('base_transfer_params', {}))
            )
            parts.append('--> Space Transfer: %s\n' % transfer.space_transfer.__class__)
            parts.append(
                self.__format_params(
                    vars(transfer.space_transfer.params).items(), description.get('space_transfer_params', {})
                )
            )

        parts.append('\n')
        parts.append(self.get_convergence_controllers_as_table(description))
        parts.append('\n')
        self.logger.info(''.join(parts))

        self.logger.info(
            '----------------------------------------------------------------------------------------------------'
        )
        self.logger.info('Setup overview (--> user-defined, -> dependency) -- END\n')

    def run(self, u0: Any, t0: float, Tend: float) -> Any:
        """
        Abstract interface to the run() method

        Args:
            u0: initial values
            t0 (float): starting time
            Tend (float): ending time

        Raises:
            NotImplementedError: always, derived controllers have to override this method
        """
        raise NotImplementedError('ERROR: controller has to implement run(self, u0, t0, Tend)')

    @property
    def hooks(self) -> List[Any]:
        """
        Getter for the hooks

        Returns:
            list: the hook instances in the order they were added
        """
        return self.__hooks

    def setup_convergence_controllers(self, description: Dict[str, Any]) -> None:
        '''
        Setup variables needed for convergence controllers, notably a list containing all of them and a list
        containing their order. Afterwards, all convergence controllers listed under `convergence_controllers` in
        the description are instantiated. The classes in `base_convergence_controllers` are not added here.

        Args:
            description (dict): The description object used to instantiate the controller

        Returns:
            None
        '''
        self.convergence_controllers: List[Any] = []
        # List of indices specifying the order of convergence controllers
        self.convergence_controller_order: List[int] = []

        # instantiate the convergence controllers
        for conv_class, params in description.get('convergence_controllers', {}).items():
            self.add_convergence_controller(conv_class, description=description, params=params)

        return None

    def add_convergence_controller(
        self,
        convergence_controller: Type[Any],
        description: Dict[str, Any],
        params: Optional[Dict[str, Any]] = None,
        allow_double: bool = False,
    ) -> None:
        '''
        Add an individual convergence controller to the list of convergence controllers and instantiate it.
        Afterwards, the order of the convergence controllers is updated. If a convergence controller of the same
        class is already present and `allow_double` is not set, nothing happens.

        Args:
            convergence_controller (pySDC.ConvergenceController): The convergence controller to be added
            description (dict): The description object used to instantiate the controller
            params (dict): Parameters for the convergence controller, `useMPI` is always added to a copy of them
            allow_double (bool): Allow adding the same convergence controller multiple times

        Returns:
            None
        '''
        # work on a copy such that the dictionary of the caller is not modified
        params = {**({} if params is None else params), 'useMPI': self.useMPI}

        # check if we already have the convergence controller or if we want to have it multiple times
        present = convergence_controller in (type(me) for me in self.convergence_controllers)
        if allow_double or not present:
            self.convergence_controllers.append(convergence_controller(self, params, description))

            # update ordering; the stable sort keeps controllers of equal order in the order they were added
            orders = [C.params.control_order for C in self.convergence_controllers]
            self.convergence_controller_order = np.argsort(orders, kind='stable')

        return None

    def get_convergence_controllers_as_table(self, description: Dict[str, Any]) -> str:
        '''
        This function is for debugging purposes to keep track of the different convergence controllers and their order.

        Args:
            description (dict): Description of the problem

        Returns:
            str: Table of convergence controllers as a string
        '''
        user_classes = description.get('convergence_controllers', {})

        rows = [
            'Active convergence controllers:',
            ' | # | order | convergence controller',
            '----+----+-------+---------------------------------------------------------------------------------------',
        ]
        for i, index in enumerate(self.convergence_controller_order):
            C = self.convergence_controllers[index]

            # figure out how the convergence controller was added
            if type(C) in user_classes:  # added by user
                user_added = '--> '
            elif type(C) in self.base_convergence_controllers:  # added by default
                user_added = ' '
            else:  # added as dependency
                user_added = ' -> '

            rows.append(f'{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}')

        return '\n'.join(rows)

    def return_stats(self) -> Dict[Any, Any]:
        """
        Return the merged stats from all hooks

        If several hooks report the same key, the hook that was added last wins.

        Returns:
            dict: Merged stats from all hooks
        """
        stats: Dict[Any, Any] = {}
        for hook in self.hooks:
            stats.update(hook.return_stats())
        return stats

375 

376 

class ParaDiagController(Controller):
    """
    Base class for ParaDiag controllers

    Adds weighted forward and backward FFTs in time on top of the base controller. The transforms are carried out
    by `apply_matrix`, which is not defined in this class; presumably derived classes provide it -- confirm.
    """

    def __init__(
        self,
        controller_params: Dict[str, Any],
        description: Dict[str, Any],
        n_steps: int,
        useMPI: Optional[bool] = None,
    ) -> None:
        """
        Initialization routine for ParaDiag controllers

        The dictionaries are modified in place: `all_to_done` is set to True and `average_jacobian` defaults to True
        in the controller parameters; `ignore_ic` and `update_f_evals` are set in the sweeper parameters if the
        sweeper class is derived from `QDiagonalization`.

        Args:
            controller_params: parameter set for the controller and the steps, has to contain `alpha`, the
                weighting parameter for ParaDiag
            description: all the parameters to set up the rest (levels, problems, transfer, ...)
            n_steps (int): Number of parallel steps
            useMPI (bool): forwarded to the base controller

        Raises:
            NotImplementedError: if `all_to_done` is explicitly switched off in the controller parameters
            ParameterError: if `alpha` is missing from the controller parameters
        """
        from pySDC.implementations.sweeper_classes.ParaDiagSweepers import QDiagonalization

        if issubclass(description['sweeper_class'], QDiagonalization):
            description['sweeper_params']['ignore_ic'] = True
            description['sweeper_params']['update_f_evals'] = False
        else:
            logging.getLogger('controller').warning(
                f'Warning: Your sweeper class {description["sweeper_class"]} is not derived from {QDiagonalization}. You probably want to use another sweeper class.'
            )

        # Only an explicit request to disable `all_to_done` is an error. Accepting True is required because this
        # routine writes `all_to_done=True` into the dictionary, which may be reused for another controller.
        if not controller_params.get('all_to_done', True):
            raise NotImplementedError('ParaDiag only implemented with option `all_to_done=True`')
        if 'alpha' not in controller_params:
            from pySDC.core.errors import ParameterError

            raise ParameterError('Please supply alpha as a parameter to the ParaDiag controller!')
        controller_params.setdefault('average_jacobian', True)

        controller_params['all_to_done'] = True
        super().__init__(controller_params=controller_params, description=description, useMPI=useMPI)

        self.n_steps: int = n_steps

    def FFT_in_time(self, quantity: Any) -> None:
        """
        Compute weighted forward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        The transform matrix is cached and only rebuilt if the number of steps or alpha changed since the last call.

        Note: The implementation via matrix-vector multiplication may be inefficient and less stable compared to an FFT
        with transposes!

        Args:
            quantity: passed on to `apply_matrix` together with the transform matrix
        """
        transform_key = (self.n_steps, self.params.alpha)

        # Attribute access is used instead of `hasattr` with a string: inside the class the attribute name is
        # mangled, a string literal is not, hence `hasattr(self, '__FFT_matrix')` would never find the cache.
        try:
            up_to_date = self.__FFT_key == transform_key
        except AttributeError:
            up_to_date = False

        if not up_to_date:
            from pySDC.helpers.ParaDiagHelper import get_weighted_FFT_matrix

            self.__FFT_matrix = get_weighted_FFT_matrix(self.n_steps, self.params.alpha)
            self.__FFT_key = transform_key

        self.apply_matrix(self.__FFT_matrix, quantity)

    def iFFT_in_time(self, quantity: Any) -> None:
        """
        Compute weighted backward FFT in time. The weighting is determined by the alpha parameter in ParaDiag

        The transform matrix is cached and only rebuilt if the number of steps or alpha changed since the last call.

        Args:
            quantity: passed on to `apply_matrix` together with the transform matrix
        """
        transform_key = (self.n_steps, self.params.alpha)

        # see FFT_in_time: the cache has to be looked up via attribute access because of name mangling
        try:
            up_to_date = self.__iFFT_key == transform_key
        except AttributeError:
            up_to_date = False

        if not up_to_date:
            from pySDC.helpers.ParaDiagHelper import get_weighted_iFFT_matrix

            self.__iFFT_matrix = get_weighted_iFFT_matrix(self.n_steps, self.params.alpha)
            self.__iFFT_key = transform_key

        self.apply_matrix(self.__iFFT_matrix, quantity)

441 

442 self.apply_matrix(self.__iFFT_matrix, quantity)