Coverage for pySDC/projects/DAE/run/run_convergence_test.py: 100%
60 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1import numpy as np
2import statistics
3import pickle
5from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
6from pySDC.projects.DAE.problems.simpleDAE import SimpleDAE
7from pySDC.projects.DAE.sweepers.fullyImplicitDAE import FullyImplicitDAE
8from pySDC.projects.DAE.misc.hooksDAE import LogGlobalErrorPostStepDifferentialVariable
9from pySDC.helpers.stats_helper import get_sorted
10from pySDC.helpers.stats_helper import filter_stats
def setup():
    """
    Assemble the parameter sets for the DAE convergence study.

    Returns
    -------
    description : dict
        Problem, sweeper, level and step configuration handed to the controller.
    controller_params : dict
        Controller options: logger level and the hook class used for error logging.
    run_params : dict
        Study settings: start/end time, list of step sizes, list of QDelta
        preconditioner names and list of collocation node counts.
    """
    # Read in by the problem class: tolerance for the implicit (Newton) solver.
    problem_params = {'newton_tol': 1e-3}

    # Read in by the sweeper class.
    sweeper_params = {'quad_type': 'RADAU-RIGHT'}

    # Level settings (residual tolerance) and step settings (iteration cap).
    level_params = {'restol': 1e-10}
    step_params = {'maxiter': 30}

    controller_params = {
        'logger_level': 30,
        'hook_class': LogGlobalErrorPostStepDifferentialVariable,
    }

    # Description dictionary from which the controller builds the hierarchy.
    description = {
        'problem_class': SimpleDAE,
        'problem_params': problem_params,
        'sweeper_class': FullyImplicitDAE,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
    }

    num_samples = 2
    run_params = {
        't0': 0.0,
        'tend': 0.1,
        # num_samples step sizes, logarithmically spaced from 1e-2 down to 1e-3
        'dt_list': np.logspace(-2, -3, num=num_samples),
        'qd_list': ['IE', 'LU'],
        'num_nodes_list': [3],
    }

    return description, controller_params, run_params
def run(description, controller_params, run_params):
    """
    Run the convergence study over all QDelta variants, node counts and step sizes.

    For every combination a single-process ``controller_nonMPI`` is built and run
    from ``run_params['t0']`` to ``run_params['tend']``, starting from the problem's
    ``u_exact(t0)``.

    Note: ``description`` is modified in place. After the call,
    ``description['sweeper_params']['QI']``, ``description['sweeper_params']['num_nodes']``
    and ``description['level_params']['dt']`` hold the values of the last run.

    Parameters
    ----------
    description : dict
        Controller description, e.g. as returned by ``setup()``.
    controller_params : dict
        Controller parameters, e.g. as returned by ``setup()``.
    run_params : dict
        Must provide ``'t0'``, ``'tend'``, ``'dt_list'``, ``'qd_list'`` and
        ``'num_nodes_list'``.

    Returns
    -------
    dict
        ``conv_data[qd_type][num_nodes]`` is a dict with
        ``'error'`` -- per step size, the max-norm over all logged
        ``'e_global_differential_post_step'`` values,
        ``'niter'`` -- per step size, the rounded mean of the logged ``'niter'`` values,
        ``'dt'`` -- the step sizes (``run_params['dt_list']`` itself).
    """
    dt_list = run_params['dt_list']
    conv_data = dict()

    for qd_type in run_params['qd_list']:
        description['sweeper_params']['QI'] = qd_type
        conv_data[qd_type] = dict()

        for num_nodes in run_params['num_nodes_list']:
            description['sweeper_params']['num_nodes'] = num_nodes

            # Result record for this (QDelta, node count) pair, filled per step size below.
            data = {
                'error': np.zeros_like(dt_list),
                'niter': np.zeros_like(dt_list, dtype='int'),
                'dt': dt_list,
            }
            conv_data[qd_type][num_nodes] = data

            for j, dt in enumerate(dt_list):
                print('Working on Qdelta=%s -- num. nodes=%i -- dt=%f' % (qd_type, num_nodes, dt))
                description['level_params']['dt'] = dt

                # A fresh single-process controller is built for every step size.
                controller = controller_nonMPI(
                    num_procs=1, controller_params=controller_params, description=description
                )

                # Initial value: the problem's exact solution at t0.
                P = controller.MS[0].levels[0].prob
                uinit = P.u_exact(run_params['t0'])

                # Only the statistics are needed; the final solution is discarded.
                _, stats = controller.run(u0=uinit, t0=run_params['t0'], Tend=run_params['tend'])

                # Entries are pairs whose second item is the logged value
                # (presumably (time, value), sorted by time).
                err = get_sorted(stats, type='e_global_differential_post_step', sortby='time')
                niter = filter_stats(stats, type='niter')

                # Max-norm over the logged error values; unpack the pairs instead of
                # re-using the index name `j` inside a comprehension.
                data['error'][j] = np.linalg.norm([value for _, value in err], np.inf)
                data['niter'][j] = round(statistics.mean(niter.values()))
                print("Error is", data['error'][j])

    return conv_data
if __name__ == "__main__":
    # Routine to run convergence tests for the fully implicit solver using the specified example
    # with various preconditioners, time step sizes and collocation node counts.
    # Error data is stored in a dictionary and then pickled for use with the loglog_plot.py routine.
    import os

    description, controller_params, run_params = setup()
    conv_data = run(description, controller_params, run_params)

    # Make sure the output directory exists; otherwise open() raises FileNotFoundError.
    os.makedirs("data", exist_ok=True)
    # Context manager guarantees the pickle file is flushed and closed.
    with open("data/dae_conv_data.p", 'wb') as f:
        pickle.dump(conv_data, f)
    print("Done")