Coverage for pySDC/projects/DAE/run/run_convergence_test.py: 100%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import numpy as np 

2import statistics 

3import pickle 

4 

5from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

6from pySDC.projects.DAE.problems.simpleDAE import SimpleDAE 

7from pySDC.projects.DAE.sweepers.fullyImplicitDAE import FullyImplicitDAE 

8from pySDC.projects.DAE.misc.hooksDAE import LogGlobalErrorPostStepDifferentialVariable 

9from pySDC.helpers.stats_helper import get_sorted 

10from pySDC.helpers.stats_helper import filter_stats 

11 

12 

def setup():
    """
    Build the parameter sets for the DAE convergence test.

    Returns
    -------
    description : dict
        Problem/sweeper classes and the problem, sweeper, level and step
        parameter dictionaries used to build the controller hierarchy.
    controller_params : dict
        Logger level and hook class for the controller.
    run_params : dict
        Start and end time, plus the lists of time step sizes, QDelta
        preconditioners and collocation node counts to sweep over.
    """
    num_samples = 2

    # Fill description dictionary for easy hierarchy creation
    description = {
        'problem_class': SimpleDAE,
        # tolerance for the implicit solver
        'problem_params': {'newton_tol': 1e-3},
        'sweeper_class': FullyImplicitDAE,
        'sweeper_params': {'quad_type': 'RADAU-RIGHT'},
        'level_params': {'restol': 1e-10},
        'step_params': {'maxiter': 30},
    }

    # controller configuration: logging verbosity and the error-logging hook
    controller_params = {
        'logger_level': 30,
        'hook_class': LogGlobalErrorPostStepDifferentialVariable,
    }

    # simulation interval and the parameter sweep (dt from 1e-2 down to 1e-3)
    run_params = {
        't0': 0.0,
        'tend': 0.1,
        'dt_list': np.logspace(-2, -3, num=num_samples),
        'qd_list': ['IE', 'LU'],
        'num_nodes_list': [3],
    }

    return description, controller_params, run_params

57 

58 

def _run_single(description, controller_params, run_params):
    """
    Run one simulation with the current contents of ``description``.

    Returns a tuple ``(error, niter)``: the infinity norm (maximum absolute
    value) of the ``e_global_differential_post_step`` stats over all steps,
    and the mean of the ``niter`` stats rounded to the nearest integer.
    """
    # instantiate the controller
    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    # get initial values
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(run_params['t0'])

    # call main function to get things done...
    _, stats = controller.run(u0=uinit, t0=run_params['t0'], Tend=run_params['tend'])

    # collect the logged global errors (sorted by time) and iteration counts
    err = get_sorted(stats, type='e_global_differential_post_step', sortby='time')
    niter = filter_stats(stats, type='niter')

    # each entry of ``err`` is indexed as (time, value); take the values only
    error = np.linalg.norm([entry[1] for entry in err], np.inf)
    mean_niter = round(statistics.mean(niter.values()))
    return error, mean_niter


def run(description, controller_params, run_params):
    """
    Routine to run simulation

    Sweeps over every combination of QDelta type (``run_params['qd_list']``),
    number of collocation nodes (``run_params['num_nodes_list']``) and time
    step size (``run_params['dt_list']``), running one simulation per
    combination.

    Note: ``description`` is modified in place. Its ``sweeper_params['QI']``,
    ``sweeper_params['num_nodes']`` and ``level_params['dt']`` entries are
    overwritten and keep the values of the last run.

    Returns
    -------
    dict
        ``conv_data[qd_type][num_nodes]`` is a dict with keys ``'error'``
        (float array, one entry per dt), ``'niter'`` (int array, one entry
        per dt) and ``'dt'`` (the caller's ``run_params['dt_list']`` object
        itself, shared by all entries rather than copied).
    """
    conv_data = dict()

    for qd_type in run_params['qd_list']:
        description['sweeper_params']['QI'] = qd_type
        conv_data[qd_type] = dict()

        for num_nodes in run_params['num_nodes_list']:
            description['sweeper_params']['num_nodes'] = num_nodes
            entry = {
                'error': np.zeros_like(run_params['dt_list']),
                'niter': np.zeros_like(run_params['dt_list'], dtype='int'),
                'dt': run_params['dt_list'],
            }
            conv_data[qd_type][num_nodes] = entry

            for j, dt in enumerate(run_params['dt_list']):
                print('Working on Qdelta=%s -- num. nodes=%i -- dt=%f' % (qd_type, num_nodes, dt))
                description['level_params']['dt'] = dt

                entry['error'][j], entry['niter'][j] = _run_single(description, controller_params, run_params)
                print("Error is", entry['error'][j])
    return conv_data

99 

100 

if __name__ == "__main__":
    # Routine to run convergence tests for the fully implicit solver using the specified example
    # with various preconditioners, time step sizes and collocation node counts.
    # Error data is stored in a dictionary and then pickled for use with the loglog_plot.py routine.
    import os

    description, controller_params, run_params = setup()
    conv_data = run(description, controller_params, run_params)

    # create the output directory if needed and close the file handle deterministically
    os.makedirs("data", exist_ok=True)
    with open("data/dae_conv_data.p", 'wb') as f:
        pickle.dump(conv_data, f)
    print("Done")