Coverage for pySDC/projects/SDC_showdown/SDC_timing_Fisher.py: 100%

109 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import os 

2import pickle 

3 

4import numpy as np 

5 

6import pySDC.helpers.plot_helper as plt_helper 

7from pySDC.helpers.stats_helper import get_sorted 

8 

9from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

10from pySDC.implementations.problem_classes.GeneralizedFisher_1D_PETSc import ( 

11 petsc_fisher_multiimplicit, 

12 petsc_fisher_fullyimplicit, 

13 petsc_fisher_semiimplicit, 

14) 

15from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

16from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

17from pySDC.implementations.sweeper_classes.multi_implicit import multi_implicit 

18 

19 

def setup_parameters():
    """
    Build the default parameter set shared by every SDC variant in this study.

    The problem and sweeper classes are deliberately left as ``None``; the caller
    (``run_SDC_variant``) fills them in for the chosen variant. Some entries
    (e.g. ``Q1``/``Q2``) are only consumed by particular sweepers, so not every
    run uses every parameter.

    Returns:
        description (dict): step description holding problem, sweeper, level and step parameters
        controller_params (dict): parameters passed to the controller
    """
    # level: residual tolerance, time-step size and sweeps per level
    level_params = dict(restol=1e-06, dt=0.25, nsweeps=[1])

    # sweeper: collocation setup and the preconditioners for the various sweeper types
    sweeper_params = dict(
        quad_type='RADAU-RIGHT',
        num_nodes=[3],
        Q1=['LU'],
        Q2=['LU'],
        QI=['LU'],
        initial_guess='zero',
    )

    # problem: Fisher equation setup plus nonlinear/linear solver tolerances and caps
    problem_params = dict(
        nu=1,
        nvars=2049,
        lambda0=2.0,
        interval=(-50, 50),
        nlsol_tol=1e-10,
        nlsol_maxiter=100,
        lsol_tol=1e-10,
        lsol_maxiter=100,
    )

    step_params = dict(maxiter=50)

    controller_params = dict(logger_level=30)

    # problem_class / sweeper_class are placeholders to be set per variant
    description = dict(
        problem_class=None,
        problem_params=problem_params,
        sweeper_class=None,
        sweeper_params=sweeper_params,
        level_params=level_params,
        step_params=step_params,
    )

    return description, controller_params

81 

82 

def run_SDC_variant(variant=None, inexact=False):
    """
    Run a single SDC variant on the generalized Fisher problem and report statistics.

    The variant name selects the problem class and matching sweeper. With
    ``inexact=True`` the problem parameter ``nlsol_maxiter`` is set to 1. The run
    uses one process from t0=0.0 to Tend=1.0, then prints iteration, solver and
    timing statistics and checks the error against ``P.u_exact(Tend)``.

    Args:
        variant (str): 'fully-implicit', 'semi-implicit' or 'multi-implicit'
        inexact (bool): if True, cap the nonlinear solver at one iteration (inexact solve); otherwise not

    Returns:
        timing (float): value of the first 'timing_run' statistic
        niter (float): mean number of iterations per step

    Raises:
        NotImplementedError: if ``variant`` is not one of the supported names
        AssertionError: if the error is not below 9.2e-05 or the mean iteration count exceeds 10
    """
    description, controller_params = setup_parameters()

    # each supported variant maps to its problem class and the sweeper that fits it
    variant_table = (
        ('fully-implicit', petsc_fisher_fullyimplicit, generic_implicit),
        ('semi-implicit', petsc_fisher_semiimplicit, imex_1st_order),
        ('multi-implicit', petsc_fisher_multiimplicit, multi_implicit),
    )
    for name, problem_class, sweeper_class in variant_table:
        if variant == name:
            description['problem_class'] = problem_class
            description['sweeper_class'] = sweeper_class
            break
    else:
        raise NotImplementedError('Wrong variant specified, got %s' % variant)

    if inexact:
        # a single nonlinear iteration per solve makes every solve inexact
        description['problem_params']['nlsol_maxiter'] = 1
    banner = ('Working on inexact %s variant...' if inexact else 'Working on exact %s variant...') % variant
    print(banner)

    t0, Tend = 0.0, 1.0

    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    # the problem instance on the finest level provides initial and reference solutions
    prob = controller.MS[0].levels[0].prob

    uend, stats = controller.run(u0=prob.u_exact(t0), t0=t0, Tend=Tend)

    err = abs(prob.u_exact(Tend) - uend)

    # iteration counts per step, ordered by time
    niters = np.array([entry[1] for entry in get_sorted(stats, type='niter', sortby='time')])
    mean_niter = np.mean(niters)

    print(' Mean number of iterations: %4.2f' % mean_niter)
    print(' Range of values for number of iterations: %2i ' % np.ptp(niters))
    print(' Position of max/min number of iterations: %2i -- %2i' % (int(np.argmax(niters)), int(np.argmin(niters))))
    print(' Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters))))

    print('Iteration count (nonlinear/linear): %i / %i' % (prob.snes_itercount, prob.ksp_itercount))
    # max(..., 1) guards against division by zero when a solver was never called
    snes_per_call = prob.snes_itercount / max(prob.snes_ncalls, 1)
    ksp_per_call = prob.ksp_itercount / max(prob.ksp_ncalls, 1)
    print('Mean Iteration count per call: %4.2f / %4.2f' % (snes_per_call, ksp_per_call))

    elapsed = get_sorted(stats, type='timing_run', sortby='time')[0][1]

    print('Time to solution: %6.4f sec.' % elapsed)
    print('Error vs. PDE solution: %6.4e' % err)
    print()

    assert err < 9.2e-05, 'ERROR: variant %s did not match error tolerance, got %s' % (variant, err)
    assert mean_niter <= 10, 'ERROR: number of iterations is too high, got %s' % mean_niter

    return elapsed, mean_niter

167 

168 

def show_results(fname):
    """
    Plot stored run times as a bar chart and save the figure.

    Reads ``fname + '.pkl'``, which is expected to be the dict written by
    ``main()``. That dict maps ``(variant, 'exact' | 'inexact')`` keys to tuples
    whose first entry is the run time. Entries are sorted by descending run time
    and drawn one bar each, then the figure is saved via ``plt_helper.savefig``.

    Args:
        fname (str): path prefix (without extension) for the pickle input and the plot outputs

    Raises:
        AssertionError: if the PDF or PNG output file was not created
    """
    # The context manager closes the file even if unpickling fails.
    # NOTE: pickle must only be used on trusted files (here: produced by main()).
    with open(fname + '.pkl', 'rb') as fobj:
        results = pickle.load(fobj)

    plt_helper.mpl.style.use('classic')
    plt_helper.setup_mpl()

    plt_helper.newfig(textwidth=238.96, scale=1.0)

    xcoords = list(range(len(results)))
    # (key, timing) pairs, slowest first
    sorted_data = sorted(((key, value[0]) for key, value in results.items()), reverse=True, key=lambda tup: tup[1])
    heights = [item[1] for item in sorted_data]
    # label e.g. "exact fully-implicit" -> "exact fully\nimplicit" so it fits under a bar
    keys = [(item[0][1] + ' ' + item[0][0]).replace('-', '\n') for item in sorted_data]

    plt_helper.plt.bar(xcoords, heights, align='center')

    plt_helper.plt.xticks(xcoords, keys, rotation=90)
    plt_helper.plt.ylabel('time (sec)')

    # save plot, beautify
    plt_helper.savefig(fname)

    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'

    return None

204 

205 

def main(cwd=''):
    """
    Run every SDC variant with exact and inexact solves, store the results and plot them.

    Results are stored as a dict mapping ``(variant, 'exact' | 'inexact')`` to the
    ``(timing, mean_niter)`` tuple returned by ``run_SDC_variant``. The dict is
    pickled to ``cwd + 'data/timings_SDC_variants_Fisher.pkl'`` and then passed
    to ``show_results`` for plotting.

    Args:
        cwd (str): prefix prepended verbatim to the output path 'data/...' (needed
            for testing); when non-empty it must end with a path separator

    Raises:
        AssertionError: if the pickle file was not created
    """
    # Loop over variants, exact and inexact solves
    results = {}
    for variant in ['fully-implicit', 'multi-implicit', 'semi-implicit']:
        results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
        results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)

    # dump result; the context manager closes (and flushes) the file even if pickling fails
    fname = cwd + 'data/timings_SDC_variants_Fisher'
    with open(fname + '.pkl', 'wb') as fobj:
        pickle.dump(results, fobj)
    assert os.path.isfile(fname + '.pkl'), 'ERROR: pickle did not create file'

    # visualize
    show_results(fname)


if __name__ == "__main__":
    main()