Coverage for pySDC/projects/Hamiltonian/solar_system.py: 96%
157 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 09:13 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 09:13 +0000
1import os
2from collections import defaultdict
3from mpl_toolkits.mplot3d import Axes3D
5import dill
6import numpy as np
8import pySDC.helpers.plot_helper as plt_helper
9from pySDC.helpers.stats_helper import get_sorted, filter_stats
11from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
12from pySDC.implementations.problem_classes.FullSolarSystem import full_solar_system
13from pySDC.implementations.problem_classes.OuterSolarSystem import outer_solar_system
14from pySDC.implementations.sweeper_classes.verlet import verlet
15from pySDC.implementations.transfer_classes.TransferParticles_NoCoarse import particles_to_particles
16from pySDC.projects.Hamiltonian.hamiltonian_output import hamiltonian_output
def setup_outer_solar_system():
    """
    Assemble the configuration used to run the outer solar system problem

    Returns:
        description (dict): description of the controller
        controller_params (dict): controller parameters
    """

    # controller parameters: specialized hook class for more statistics and output
    controller_params = {
        'hook_class': hamiltonian_output,
        'logger_level': 30,
    }

    # description dictionary for easy hierarchy creation; list-valued entries
    # carry two values (presumably one per level, fine level first -- confirm)
    description = {
        'problem_class': outer_solar_system,
        'problem_params': {'sun_only': [False, True]},
        'sweeper_class': verlet,
        'sweeper_params': {
            'quad_type': 'LOBATTO',
            'num_nodes': [5, 3],
            'initial_guess': 'spread',
        },
        'level_params': {'restol': 1e-10, 'dt': 100.0},
        'step_params': {'maxiter': 50},
        'space_transfer_class': particles_to_particles,
    }

    return description, controller_params
def setup_full_solar_system():
    """
    Assemble the configuration used to run the full solar system problem

    Returns:
        description (dict): description of the controller
        controller_params (dict): controller parameters
    """

    # controller parameters: specialized hook class for more statistics and output
    controller_params = {
        'hook_class': hamiltonian_output,
        'logger_level': 30,
    }

    # description dictionary for easy hierarchy creation; list-valued entries
    # carry two values (presumably one per level, fine level first -- confirm)
    description = {
        'problem_class': full_solar_system,
        'problem_params': {'sun_only': [False, True]},
        'sweeper_class': verlet,
        'sweeper_params': {
            'quad_type': 'LOBATTO',
            'num_nodes': [5, 3],
            'initial_guess': 'spread',
        },
        'level_params': {'restol': 1e-10, 'dt': 10.0},
        'step_params': {'maxiter': 50},
        'space_transfer_class': particles_to_particles,
    }

    return description, controller_params
def run_simulation(prob=None):
    """
    Routine to run the simulation of a second order problem

    A short summary of the iteration counts is written to ``data/<prob>_out.txt``
    and echoed to stdout; the full statistics are dumped with dill to
    ``data/<prob>.dat``.

    Args:
        prob (str): name of the problem, either 'outer_solar_system' or 'full_solar_system'

    Raises:
        NotImplementedError: if ``prob`` is not one of the two known problem names
        AssertionError: if the mean number of iterations exceeds the bound for the
            chosen problem, or if the stats file does not exist after dumping
    """

    if prob == 'outer_solar_system':
        description, controller_params = setup_outer_solar_system()
        # set time parameters
        t0 = 0.0
        Tend = 10000.0
        num_procs = 100
        maxmeaniter = 6.0
    elif prob == 'full_solar_system':
        description, controller_params = setup_full_solar_system()
        # set time parameters
        t0 = 0.0
        Tend = 1000.0
        num_procs = 100
        maxmeaniter = 19.0
    else:
        raise NotImplementedError('Problem type not implemented, got %s' % prob)

    # the context manager guarantees the log file is flushed and closed even if
    # setting up or running the controller raises
    with open('data/' + prob + '_out.txt', 'w') as f:

        def report(line):
            # mirror every summary line to the log file and to stdout
            f.write(line + '\n')
            print(line)

        report('Running ' + prob + ' problem with %s processors...' % num_procs)

        # instantiate the controller
        controller = controller_nonMPI(
            num_procs=num_procs, controller_params=controller_params, description=description
        )

        # get initial values on finest level
        P = controller.MS[0].levels[0].prob
        uinit = P.u_exact(t=t0)

        # call main function to get things done (the end value is not needed here)
        _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

        # filter statistics by type (number of iterations)
        iter_counts = get_sorted(stats, type='niter', sortby='time')

        # compute and print statistics
        niters = np.array([item[1] for item in iter_counts])
        report(' Mean number of iterations: %4.2f' % np.mean(niters))
        report(' Range of values for number of iterations: %2i ' % np.ptp(niters))
        report(
            ' Position of max/min number of iterations: %2i -- %2i'
            % (int(np.argmax(niters)), int(np.argmin(niters)))
        )
        report(
            ' Std and var for number of iterations: %4.2f -- %4.2f'
            % (float(np.std(niters)), float(np.var(niters)))
        )

    # checked before dumping, so no stats file is written for a run that is too expensive
    assert np.mean(niters) <= maxmeaniter, 'Mean number of iterations is too high, got %s' % np.mean(niters)

    fname = 'data/' + prob + '.dat'
    with open(fname, 'wb') as f:
        dill.dump(stats, f)

    assert os.path.isfile(fname), 'Run for %s did not create stats file' % prob
def show_results(prob=None, cwd=''):
    """
    Helper function to plot the error of the Hamiltonian and the particle positions

    Reads the dill file ``<cwd>data/<prob>.dat`` and writes the figures
    ``data/<prob>_hamiltonian`` and ``data/<prob>_positions`` via the plot helper.

    Args:
        prob (str): name of the problem
        cwd (str): current working directory (prefix for locating the stats file)

    Raises:
        NotImplementedError: if the stored positions are neither 2- nor 3-dimensional
        AssertionError: if the final Hamiltonian error is too large or an expected
            PDF/PNG output file is missing
    """

    # read in the dill data; the context manager closes the file even if loading fails
    # NOTE: dill.load can execute arbitrary code contained in the file -- only load
    # stats files that were produced by a trusted run
    with open(cwd + 'data/' + prob + '.dat', 'rb') as f:
        stats = dill.load(f)

    plt_helper.mpl.style.use('classic')
    plt_helper.setup_mpl()

    # extract error in hamiltonian, grouped by iteration and ordered by time
    extract_stats = filter_stats(stats, type='err_hamiltonian')
    result = defaultdict(list)
    for k, v in extract_stats.items():
        result[k.iter].append((k.time, v))
    for entries in result.values():
        entries.sort(key=lambda x: x[0])

    plt_helper.newfig(textwidth=238.96, scale=0.89)

    # Rearrange data for easy plotting, one curve per iteration
    err_ham = 1  # fallback that makes the check below fail if there is nothing to plot
    for k, v in result.items():
        times = [item[0] for item in v]
        ham = [item[1] for item in v]
        err_ham = ham[-1]
        plt_helper.plt.semilogy(times, ham, '-', lw=1, label='Iter ' + str(k))
    # only the final value of the last curve processed is checked
    assert err_ham < 2.4e-14, 'Error in the Hamiltonian is too large for %s, got %s' % (prob, err_ham)

    plt_helper.plt.xlabel('Time')
    plt_helper.plt.ylabel('Error in Hamiltonian')
    plt_helper.plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # NOTE(review): output paths do not use cwd, unlike the input path above -- confirm intended
    fname = 'data/' + prob + '_hamiltonian'
    plt_helper.savefig(fname)

    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'

    # extract positions and prepare for plotting
    result = get_sorted(stats, type='position', sortby='time')

    fig = plt_helper.plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Rearrange data for easy plotting; sizes are read from the second entry,
    # whose value is indexed as [dimension][particle]
    nparts = len(result[1][1][0])
    ndim = len(result[1][1])
    nsteps = len(result)
    pos = np.zeros((nparts, ndim, nsteps))

    for idx, item in enumerate(result):
        for n in range(nparts):
            for m in range(ndim):
                pos[n, m, idx] = item[1][m][n]

    # one trajectory per particle
    for n in range(nparts):
        if ndim == 2:
            ax.plot(pos[n, 0, :], pos[n, 1, :])
        elif ndim == 3:
            ax.plot(pos[n, 0, :], pos[n, 1, :], pos[n, 2, :])
        else:
            raise NotImplementedError('Wrong number of dimensions for plotting, got %s' % ndim)

    fname = 'data/' + prob + '_positions'
    plt_helper.savefig(fname)

    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'
def main():
    """
    Run the simulation and plot the results for both solar system setups, outer system first
    """
    for prob in ('outer_solar_system', 'full_solar_system'):
        run_simulation(prob)
        show_results(prob)


if __name__ == "__main__":
    main()