Coverage for pySDC/projects/SDC_showdown/SDC_timing_Fisher.py: 100%
109 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1import os
2import pickle
4import numpy as np
6import pySDC.helpers.plot_helper as plt_helper
7from pySDC.helpers.stats_helper import get_sorted
9from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
10from pySDC.implementations.problem_classes.GeneralizedFisher_1D_PETSc import (
11 petsc_fisher_multiimplicit,
12 petsc_fisher_fullyimplicit,
13 petsc_fisher_semiimplicit,
14)
15from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
16from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
17from pySDC.implementations.sweeper_classes.multi_implicit import multi_implicit
def setup_parameters():
    """
    Collect the default parameters shared by all SDC variants of this timing study

    The returned description is intentionally a superset of what any single variant needs. The problem and sweeper
    classes are left as None so that the caller can choose them per variant.

    Returns:
        description (dict): step description holding problem, sweeper, level and step parameters
        controller_params (dict): parameters handed to the controller
    """
    # level settings: residual tolerance, time-step size and number of sweeps
    level_params = {
        'restol': 1e-06,
        'dt': 0.25,
        'nsweeps': [1],
    }

    # sweeper settings: collocation nodes, 'LU' for each of Q1/Q2/QI and a zero initial guess
    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'num_nodes': [3],
        'Q1': ['LU'],
        'Q2': ['LU'],
        'QI': ['LU'],
        'initial_guess': 'zero',
    }

    # problem settings, including tolerances and iteration caps for the nonlinear (nlsol_*) and linear (lsol_*) solves
    problem_params = {
        'nu': 1,
        'nvars': 2049,
        'lambda0': 2.0,
        'interval': (-50, 50),
        'nlsol_tol': 1e-10,
        'nlsol_maxiter': 100,
        'lsol_tol': 1e-10,
        'lsol_maxiter': 100,
    }

    # step settings: cap on the number of SDC iterations
    step_params = {'maxiter': 50}

    # controller settings
    controller_params = {'logger_level': 30}

    # the step description; problem and sweeper classes are filled in by the caller
    # (no spatial transfer class or parameters are configured here)
    description = {
        'problem_class': None,
        'problem_params': problem_params,
        'sweeper_class': None,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
    }

    return description, controller_params
def run_SDC_variant(variant=None, inexact=False):
    """
    Run one SDC variant on the generalized Fisher problem, print statistics and check the result

    Args:
        variant (str): one of 'fully-implicit', 'semi-implicit' or 'multi-implicit'
        inexact (bool): if truthy, limit the nonlinear solver to a single iteration (inexact solve); otherwise not

    Returns:
        timing (float): time to solution in seconds, taken from the 'timing_run' statistics
        niter (float): mean number of iterations over all steps

    Raises:
        NotImplementedError: if variant is not one of the supported names
    """
    description, controller_params = setup_parameters()

    # pick problem and sweeper class for the requested variant
    for name, problem_class, sweeper_class in (
        ('fully-implicit', petsc_fisher_fullyimplicit, generic_implicit),
        ('semi-implicit', petsc_fisher_semiimplicit, imex_1st_order),
        ('multi-implicit', petsc_fisher_multiimplicit, multi_implicit),
    ):
        if variant == name:
            break
    else:
        raise NotImplementedError('Wrong variant specified, got %s' % variant)
    description['problem_class'] = problem_class
    description['sweeper_class'] = sweeper_class

    if inexact:
        # one nonlinear iteration per solve turns the exact solve into an inexact one
        description['problem_params']['nlsol_maxiter'] = 1
        print('Working on inexact %s variant...' % variant)
    else:
        print('Working on exact %s variant...' % variant)

    t0, Tend = 0.0, 1.0

    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    # the problem instance on the finest level provides initial values and the reference solution
    prob = controller.MS[0].levels[0].prob
    u_start = prob.u_exact(t0)

    uend, stats = controller.run(u0=u_start, t0=t0, Tend=Tend)

    err = abs(prob.u_exact(Tend) - uend)

    # iteration counts per step, ordered by time
    niters = np.array([entry[1] for entry in get_sorted(stats, type='niter', sortby='time')])
    mean_niter = np.mean(niters)

    print(' Mean number of iterations: %4.2f' % mean_niter)
    print(' Range of values for number of iterations: %2i ' % np.ptp(niters))
    print(' Position of max/min number of iterations: %2i -- %2i' % (int(np.argmax(niters)), int(np.argmin(niters))))
    print(' Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters))))

    print('Iteration count (nonlinear/linear): %i / %i' % (prob.snes_itercount, prob.ksp_itercount))
    # max(..., 1) guards against division by zero when a solver was never called
    snes_per_call = prob.snes_itercount / max(prob.snes_ncalls, 1)
    ksp_per_call = prob.ksp_itercount / max(prob.ksp_ncalls, 1)
    print('Mean Iteration count per call: %4.2f / %4.2f' % (snes_per_call, ksp_per_call))

    time_to_solution = get_sorted(stats, type='timing_run', sortby='time')[0][1]

    print('Time to solution: %6.4f sec.' % time_to_solution)
    print('Error vs. PDE solution: %6.4e' % err)
    print()

    assert err < 9.2e-05, 'ERROR: variant %s did not match error tolerance, got %s' % (variant, err)
    assert mean_niter <= 10, 'ERROR: number of iterations is too high, got %s' % mean_niter

    return time_to_solution, mean_niter
def _bar_chart_data(results):
    """
    Order timing results for the bar chart, slowest run first

    Args:
        results (dict): maps (variant, solve type) tuples to tuples whose first entry is the timing

    Returns:
        heights (list): timings in descending order
        labels (list): matching tick labels '<solve type> <variant>', with every '-' replaced by a line break
    """
    # sorted() is stable, so runs with equal timings keep their dict order
    sorted_data = sorted(((key, value[0]) for key, value in results.items()), reverse=True, key=lambda tup: tup[1])
    heights = [timing for _, timing in sorted_data]
    labels = [(key[1] + ' ' + key[0]).replace('-', '\n') for key, _ in sorted_data]
    return heights, labels


def show_results(fname):
    """
    Plotting routine: draw the stored timings as a bar chart and save it

    Args:
        fname (str): file name without extension; data is read from fname + '.pkl' and the plots are saved under
            fname (the '.pdf' and '.png' files are checked afterwards)

    Returns:
        None
    """
    # the context manager closes the file even if unpickling fails
    # NOTE: pickle.load can execute arbitrary code for crafted input -- only use it on files written by main()
    with open(fname + '.pkl', 'rb') as fh:
        results = pickle.load(fh)

    plt_helper.mpl.style.use('classic')
    plt_helper.setup_mpl()

    plt_helper.newfig(textwidth=238.96, scale=1.0)

    heights, keys = _bar_chart_data(results)
    xcoords = list(range(len(results)))

    plt_helper.plt.bar(xcoords, heights, align='center')

    plt_helper.plt.xticks(xcoords, keys, rotation=90)
    plt_helper.plt.ylabel('time (sec)')

    # save plot, beautify
    plt_helper.savefig(fname)

    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'

    return None
def main(cwd=''):
    """
    Main driver: time every SDC variant with exact and inexact solves, store the results and plot them

    Args:
        cwd (str): current working directory (need this for testing); it is prepended verbatim to
            'data/timings_SDC_variants_Fisher', so it should be empty or end with a path separator
    """

    # Loop over variants, exact and inexact solves
    results = {}
    for variant in ['fully-implicit', 'multi-implicit', 'semi-implicit']:
        results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
        results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)

    # dump result; the context manager closes the file even if pickling fails
    fname = cwd + 'data/timings_SDC_variants_Fisher'
    with open(fname + '.pkl', 'wb') as fh:
        pickle.dump(results, fh)
    assert os.path.isfile(fname + '.pkl'), 'ERROR: pickle did not create file'

    # visualize
    show_results(fname)
if __name__ == '__main__':
    # run the complete timing study when this file is executed as a script
    main()