Coverage for pySDC/projects/RDC/vanderpol_error_test.py: 0%
82 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1import matplotlib
3matplotlib.use('Agg')
5import matplotlib.pylab as plt
7import numpy as np
8import pickle
9import os
11from pySDC.implementations.problem_classes.Van_der_Pol_implicit import vanderpol
12from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
13from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
15from pySDC.projects.RDC.equidistant_RDC import Equidistant_RDC
def compute_RDC_errors(cwd=''):
    """
    Van der Pol's oscillator with RDC: final-time error versus number of iterations

    Sets up the vanderpol problem with the generic_implicit sweeper on Equidistant_RDC nodes (41 nodes, QI='IE'),
    then runs t in [0, 10] with dt = 0.25 for maxiter = 1, ..., 10. For each run the relative max-norm error of the
    returned end value against the reference loaded from 'data/vdp_ref.npy' is printed and stored. The errors
    (keyed by maxiter) and the list of maxiter values (key 'maxiter_list') are pickled to 'data/vdp_results.pkl'.

    Args:
        cwd (string): prefix prepended to both data paths (default '' = relative to the process working directory)
    """

    # initialize level parameters
    level_params = dict()
    # residual tolerance of 0, presumably so that runs are limited by maxiter rather than by the residual
    level_params['restol'] = 0
    level_params['dt'] = 10.0 / 40.0

    # initialize sweeper parameters
    sweeper_params = dict()
    sweeper_params['collocation_class'] = Equidistant_RDC
    sweeper_params['num_nodes'] = 41
    sweeper_params['QI'] = 'IE'

    # initialize problem parameters
    problem_params = dict()
    problem_params['newton_tol'] = 1e-14
    problem_params['newton_maxiter'] = 50
    problem_params['mu'] = 10
    problem_params['u0'] = (2.0, 0)

    # initialize step parameters; maxiter is a placeholder, it is overwritten in the loop below
    step_params = dict()
    step_params['maxiter'] = None

    # initialize controller parameters (30 corresponds to logging.WARNING)
    controller_params = dict()
    controller_params['logger_level'] = 30

    # Fill description dictionary for easy hierarchy creation
    description = dict()
    description['problem_class'] = vanderpol
    description['problem_params'] = problem_params
    description['sweeper_class'] = generic_implicit
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    # instantiate the controller
    controller_rdc = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    # set time parameters
    t0 = 0.0
    Tend = 10.0

    # get initial values on finest level
    P = controller_rdc.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # reference solution; assumed to be the end value at Tend produced by a separate script -- confirm
    ref_sol = np.load(cwd + 'data/vdp_ref.npy')

    maxiter_list = range(1, 11)
    results = dict()
    results['maxiter_list'] = maxiter_list

    for maxiter in maxiter_list:
        # ugly, but much faster than re-initializing the controller over and over again
        controller_rdc.MS[0].params.maxiter = maxiter

        # call main function to get things done...
        uend_rdc, stats_rdc = controller_rdc.run(u0=uinit, t0=t0, Tend=Tend)

        # relative error in the max-norm
        err = np.linalg.norm(uend_rdc - ref_sol, np.inf) / np.linalg.norm(ref_sol, np.inf)
        print('Maxiter = %2i -- Error: %8.4e' % (controller_rdc.MS[0].params.maxiter, err))
        results[maxiter] = err

    # write results; the context manager closes the file even if pickling fails
    fname = cwd + 'data/vdp_results.pkl'
    with open(fname, 'wb') as file:
        pickle.dump(results, file)

    assert os.path.isfile(fname), 'ERROR: pickle did not create file'
def plot_RDC_results(cwd=''):
    """
    Routine to visualize the errors

    Loads the results pickled by compute_RDC_errors and plots the relative error against maxiter on a
    semi-logarithmic y-axis. The figure is saved as 'data/RDC_errors_vdp.png' and closed afterwards.

    Args:
        cwd (string): current working directory; prepended to the path of the pickled results only
            (the PNG is written relative to the process working directory)
    """

    # encoding='latin-1' keeps pickles written under Python 2 (e.g. containing numpy data) loadable
    with open(cwd + 'data/vdp_results.pkl', 'rb') as file:
        results = pickle.load(file, encoding='latin-1')

    # retrieve the list of maxiters from results
    assert 'maxiter_list' in results, 'ERROR: expecting the list of maxiters in the results dictionary'
    maxiter_list = sorted(results['maxiter_list'])

    # Set up plotting parameters
    params = {
        'legend.fontsize': 20,
        'figure.figsize': (12, 8),
        'axes.labelsize': 20,
        'axes.titlesize': 20,
        'xtick.labelsize': 16,
        'ytick.labelsize': 16,
        'lines.linewidth': 3,
    }
    plt.rcParams.update(params)

    # create new figure
    fig = plt.figure()
    # take x-axis limits from maxiter_list + some spacing left and right
    plt.xlim([min(maxiter_list) - 1, max(maxiter_list) + 1])
    plt.xlabel('maxiter')
    plt.ylabel('rel. error')
    plt.grid()

    # collect the errors in the (sorted) order of maxiter_list
    err_list = [results[maxiter] for maxiter in maxiter_list]
    plt.semilogy(maxiter_list, err_list, ls='-', marker='o', markersize=10, label='RDC')

    # adjust y-axis limits (one decade of margin on each side), add legend
    plt.ylim([min(err_list) / 10, max(err_list) * 10])
    plt.legend(loc=1, ncol=1, numpoints=1)

    # save plot as PNG, then release the figure so repeated calls do not accumulate open figures
    fname = 'data/RDC_errors_vdp.png'
    plt.savefig(fname, bbox_inches='tight')
    plt.close(fig)

    assert os.path.isfile(fname), 'ERROR: plot was not created'

    return None
if __name__ == '__main__':
    # generate the error data first, then plot it from the current working directory
    compute_RDC_errors()
    plot_RDC_results(cwd='')