Coverage for pySDC/projects/TOMS/AllenCahn_contracting_circle.py: 95%

184 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import os 

2 

3import dill 

4import matplotlib.ticker as ticker 

5import numpy as np 

6 

7import pySDC.helpers.plot_helper as plt_helper 

8from pySDC.helpers.stats_helper import get_sorted 

9 

10from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

11from pySDC.implementations.problem_classes.AllenCahn_2D_FD import ( 

12 allencahn_fullyimplicit, 

13 allencahn_semiimplicit, 

14 allencahn_semiimplicit_v2, 

15 allencahn_multiimplicit, 

16 allencahn_multiimplicit_v2, 

17) 

18from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

19from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

20from pySDC.implementations.sweeper_classes.multi_implicit import multi_implicit 

21from pySDC.projects.TOMS.AllenCahn_monitor import monitor 

22 

23 

24# http://www.personal.psu.edu/qud2/Res/Pre/dz09sisc.pdf 

25 

26 

def setup_parameters():
    """
    Build the shared, partially filled default parameter set for every SDC variant in this study.

    The sweeper and problem entries form a superset: each individual variant only reads the subset it needs.
    ``problem_class`` and ``sweeper_class`` are left as None and must be filled in by the caller.

    Returns:
        description (dict): step description (problem, sweeper, level and step parameters)
        controller_params (dict): controller parameters (logger level and hook class)
    """
    level_params = {
        'restol': 1e-08,
        'dt': 1e-03,
        'nsweeps': [1],
    }

    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'num_nodes': [3],
        'Q1': ['LU'],
        'Q2': ['LU'],
        'QI': ['LU'],
        'QE': ['EE'],
        'initial_guess': 'zero',
    }

    # read in by the problem class
    problem_params = {
        'nu': 2,
        'nvars': [(128, 128)],
        'eps': [0.04],
        'newton_maxiter': 100,
        'newton_tol': 1e-09,
        'lin_tol': 1e-10,
        'lin_maxiter': 100,
        'radius': 0.25,
    }

    step_params = {'maxiter': 50}

    controller_params = {
        'logger_level': 30,
        'hook_class': monitor,
    }

    description = {
        'problem_class': None,  # chosen per variant by the caller
        'problem_params': problem_params,
        'sweeper_class': None,  # chosen per variant by the caller
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
    }

    return description, controller_params

84 

85 

def run_SDC_variant(variant=None, inexact=False):
    """
    Run one SDC variant on the contracting-circle Allen-Cahn problem and print iteration/timing statistics.

    Args:
        variant (str): one of 'fully-implicit', 'semi-implicit', 'semi-implicit_v2', 'multi-implicit',
            'multi-implicit_v2'
        inexact (bool): if True, cap the inner nonlinear and/or linear solver iterations (inexact solve)

    Returns:
        dict: the statistics collected by the controller during the run

    Raises:
        NotImplementedError: if ``variant`` is not one of the supported names
    """
    # variant -> (problem class, sweeper class, problem_params overrides applied for the inexact solve)
    variants = {
        'fully-implicit': (allencahn_fullyimplicit, generic_implicit, {'newton_maxiter': 1}),
        'semi-implicit': (allencahn_semiimplicit, imex_1st_order, {'lin_maxiter': 10}),
        'semi-implicit_v2': (allencahn_semiimplicit_v2, imex_1st_order, {'newton_maxiter': 1}),
        'multi-implicit': (allencahn_multiimplicit, multi_implicit, {'newton_maxiter': 1, 'lin_maxiter': 10}),
        'multi-implicit_v2': (allencahn_multiimplicit_v2, multi_implicit, {'newton_maxiter': 1}),
    }
    # isinstance guard keeps non-string (incl. unhashable) input on the NotImplementedError path
    if not isinstance(variant, str) or variant not in variants:
        raise NotImplementedError('Wrong variant specified, got %s' % variant)

    # load (incomplete) default parameters and fill in the variant-specific pieces
    description, controller_params = setup_parameters()
    problem_class, sweeper_class, inexact_overrides = variants[variant]
    description['problem_class'] = problem_class
    description['sweeper_class'] = sweeper_class
    if inexact:
        description['problem_params'].update(inexact_overrides)

    print('Working on %s %s variant...' % ('inexact' if inexact else 'exact', variant))

    # setup parameters "in time"
    t0 = 0
    Tend = 0.032

    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # the final solution itself is not needed here, only the statistics
    _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    # iteration counts per step, sorted by time
    iter_counts = get_sorted(stats, type='niter', sortby='time')
    niters = np.array([entry[1] for entry in iter_counts])

    print(' Mean number of iterations: %4.2f' % np.mean(niters))
    print(' Range of values for number of iterations: %2i ' % np.ptp(niters))
    print(' Position of max/min number of iterations: %2i -- %2i' % (int(np.argmax(niters)), int(np.argmin(niters))))
    print(' Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters))))

    print(' Iteration count (nonlinear/linear): %i / %i' % (P.newton_itercount, P.lin_itercount))
    # max(..., 1) avoids division by zero when a solver was never called
    print(
        ' Mean Iteration count per call: %4.2f / %4.2f'
        % (P.newton_itercount / max(P.newton_ncalls, 1), P.lin_itercount / max(P.lin_ncalls, 1))
    )

    timing = get_sorted(stats, type='timing_run', sortby='time')

    print('Time to solution: %6.4f sec.' % timing[0][1])
    print()

    return stats

178 

179 

def _save_and_check_plot(f):
    """Save the figure via ``plt_helper.savefig(f)`` and assert that the PDF and PNG files were created."""
    plt_helper.savefig(f)

    assert os.path.isfile(f + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(f + '.png'), 'ERROR: plotting did not create PNG file'


def _plot_timings(results, fname):
    """Bar plot of time to solution (left axis) and mean iteration count (right axis), slowest run first."""
    _, ax1 = plt_helper.newfig(textwidth=238.96, scale=1.5, ratio=0.4)

    timings = {}
    niters = {}
    for key, stats in results.items():
        timings[key] = get_sorted(stats, type='timing_run', sortby='time')[0][1]
        iter_counts = get_sorted(stats, type='niter', sortby='time')
        niters[key] = np.mean(np.array([entry[1] for entry in iter_counts]))

    # order runs by decreasing time to solution (stable for ties)
    sorted_keys = sorted(timings, key=lambda k: timings[k], reverse=True)

    xcoords = list(range(len(sorted_keys)))
    heights_timings = [timings[k] for k in sorted_keys]
    heights_niters = [niters[k] for k in sorted_keys]
    # keys are (variant, 'exact'/'inexact'); build multi-line tick labels, e.g. 'exact fully\nimplicit'
    tick_labels = [(k[1] + ' ' + k[0]).replace('-', '\n').replace('_v2', ' mod.') for k in sorted_keys]

    ax1.bar(xcoords, heights_timings, align='edge', width=-0.3, label='timings (left axis)')
    ax1.set_ylabel('time (sec)')

    ax2 = ax1.twinx()
    ax2.bar(xcoords, heights_niters, color='lightcoral', align='edge', width=0.3, label='iterations (right axis)')
    ax2.set_ylabel('mean number of iterations')

    ax1.set_xticks(xcoords)
    ax1.set_xticklabels(tick_labels, rotation=90, ha='center')

    # merge the legend entries of both axes into a single legend
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc=0)

    _save_and_check_plot(fname + '_timings')


def _plot_radii(results, fname):
    """Plot the fully-implicit exact radius against the exact radius and check every run's deviation."""
    _, ax = plt_helper.newfig(textwidth=238.96, scale=1.0)

    exact_radii = []
    for key, stats in results.items():
        computed_radii = get_sorted(stats, type='computed_radius', sortby='time')
        name = key[0] + ' ' + key[1]
        if name == 'fully-implicit exact':
            ax.plot(
                [entry[0] for entry in computed_radii],
                [entry[1] for entry in computed_radii],
                label=name.replace('_v2', ' mod.'),
            )

        exact_radii = get_sorted(stats, type='exact_radius', sortby='time')

        # each run must stay close to the exact radius, with its largest deviation at 0.028 < t < 0.03
        diff = np.array([abs(exact[1] - computed[1]) for exact, computed in zip(exact_radii, computed_radii)])
        max_pos = int(np.argmax(diff))
        assert max(diff) < 0.07, 'ERROR: computed radius is too far away from exact radius, got %s' % max(diff)
        assert 0.028 < computed_radii[max_pos][0] < 0.03, (
            'ERROR: largest difference is at wrong time, got %s' % computed_radii[max_pos][0]
        )

    # exact radii of the last processed run (presumably identical for all runs)
    ax.plot(
        [entry[0] for entry in exact_radii],
        [entry[1] for entry in exact_radii],
        color='k',
        linestyle='--',
        linewidth=1,
        label='exact',
    )

    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.2f'))
    ax.set_ylabel('radius')
    ax.set_xlabel('time')
    ax.grid()
    ax.legend(loc=3)

    _save_and_check_plot(fname + '_radii')


def _plot_interface_width(results, fname):
    """Plot the fully-implicit exact interface width against a constant reference (initial width)."""
    _, ax = plt_helper.newfig(textwidth=238.96, scale=1.0)

    interface_width = []
    for key, stats in results.items():
        interface_width = get_sorted(stats, type='interface_width', sortby='time')
        name = key[0] + ' ' + key[1]
        if name == 'fully-implicit exact':
            ax.plot([entry[0] for entry in interface_width], [entry[1] for entry in interface_width], label=name)

    # reference line: initial width of the last processed run, held constant over time
    xcoords = [entry[0] for entry in interface_width]
    init_width = [interface_width[0][1]] * len(xcoords)
    ax.plot(xcoords, init_width, color='k', linestyle='--', linewidth=1, label='exact')

    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.2f'))
    ax.set_ylabel(r'interface width ($\epsilon$)')
    ax.set_xlabel('time')
    ax.grid()
    ax.legend(loc=3)

    _save_and_check_plot(fname + '_interface')


def show_results(fname, cwd=''):
    """
    Plot the timings, radii and interface widths stored by ``main`` and sanity-check them.

    The pickled results are read from ``cwd + fname + '.pkl'``. The figures are saved under
    ``fname + '_timings'``, ``fname + '_radii'`` and ``fname + '_interface'`` (``cwd`` is not prepended).

    Args:
        fname (str): file name (without extension) to read in and to name the plots after
        cwd (str): prefix prepended to ``fname`` when reading the results file

    Raises:
        AssertionError: if a plot file is missing or a computed radius deviates too much from the exact one
    """
    # NOTE: dill can execute arbitrary code on load; the file is assumed to come from this module's own main()
    with open(cwd + fname + '.pkl', 'rb') as file:
        results = dill.load(file)

    plt_helper.setup_mpl()

    _plot_timings(results, fname)
    _plot_radii(results, fname)
    _plot_interface_width(results, fname)

    return None

305 

306 

def main(cwd=''):
    """
    Main driver: run every SDC variant with exact and inexact solves, store the statistics and plot them.

    Args:
        cwd (str): prefix (working directory including trailing separator) prepended to the results
            file path; needed for testing
    """
    # Loop over variants, exact and inexact solves
    results = {}
    for variant in ['multi-implicit', 'semi-implicit', 'fully-implicit', 'semi-implicit_v2', 'multi-implicit_v2']:
        results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
        results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)

    fname = 'data/results_SDC_variants_AllenCahn_1E-03'

    # make sure the target directories exist: the pickle goes to cwd + fname, the figures to fname
    for path in (cwd + fname, fname):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # dump result
    with open(cwd + fname + '.pkl', 'wb') as file:
        dill.dump(results, file)
    assert os.path.isfile(cwd + fname + '.pkl'), 'ERROR: dill did not create file'

    # visualize
    show_results(fname, cwd=cwd)

330 

331 

if __name__ == "__main__":
    # script entry point: run the full study relative to the current working directory
    main(cwd='')