Coverage for pySDC/tutorial/step_8/HookClass_error_output.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1from pySDC.core.hooks import Hooks 

2from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

3from pySDC.implementations.problem_classes.Auzinger_implicit import auzinger 

4 

5 

class error_output(Hooks):
    """
    Hook class to add output of error

    Before each step, a nested single-process pySDC run with a tight residual tolerance (1e-14) is performed to
    obtain a reference solution of the collocation problem. After each step, two errors of the step's end-point
    value are stored in the stats:

    - ``PDE_error_after_step``: error against ``P.u_exact`` at the end of the step
    - ``coll_error_after_step``: error against the reference solution of the nested run
    """

    def __init__(self):
        super(error_output, self).__init__()
        # reference solution of the nested collocation run; set in pre_step, read in post_step
        self.uex = None

    def pre_step(self, step, level_number):
        """
        Default routine called before each step, here used to compute a reference collocation solution

        The step's description and controller parameters are copied before being modified. The step's own
        dictionaries are therefore left untouched, and the hook can safely run more than once on the same
        step object.

        Args:
            step: the current step
            level_number: the current level number
        """
        super(error_output, self).pre_step(step, level_number)

        L = step.levels[level_number]

        # This is a bit black magic: we are going to run pySDC within the hook to check the error against the "exact"
        # solution of the collocation problem.
        # Work on copies: editing step.params in place would alter dictionaries that may be shared with the caller
        # or other steps, and a repeated call would fail on the already-deleted 'hook_class' key.
        description = dict(step.params.description)
        description['level_params'] = {**description['level_params'], 'restol': 1e-14}
        if not isinstance(L.prob, auzinger):
            description['problem_params'] = {**description['problem_params'], 'solver_type': 'direct'}

        # get rid of the hook, otherwise this will be an endless recursion..
        controller_params = {
            key: value for key, value in step.params.controller_params.items() if key != 'hook_class'
        }
        controller_params['logger_level'] = 90  # silence the nested run
        controller_params['convergence_controllers'] = {}

        controller = controller_nonMPI(num_procs=1, description=description, controller_params=controller_params)
        self.uex, _ = controller.run(u0=L.u[0], t0=L.time, Tend=L.time + L.dt)

    def post_step(self, step, level_number):
        """
        Default routine called after each step, here used to record the errors of the end-point value

        Args:
            step: the current step
            level_number: the current level number
        """
        super(error_output, self).post_step(step, level_number)

        # some abbreviations
        L = step.levels[level_number]
        P = L.prob

        # (re)compute the end point of the step; L.uend is read below
        L.sweep.compute_end_point()

        # compute and save errors (insertion order fixes the order in which stats are added)
        upde = P.u_exact(step.time + step.dt)
        errors = {
            'PDE_error_after_step': abs(upde - L.uend),
            'coll_error_after_step': abs(self.uex - L.uend),
        }

        for error_type, error_value in errors.items():
            self.add_to_stats(
                process=step.status.slot,
                time=L.time + L.dt,
                level=L.level_index,
                iter=step.status.iter,
                sweep=L.status.sweep,
                type=error_type,
                value=error_value,
            )