Coverage for pySDC / projects / TOMS / AllenCahn_contracting_circle.py: 95%

184 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-13 09:00 +0000

1import os 

2 

3import dill 

4import matplotlib.ticker as ticker 

5import numpy as np 

6 

7import pySDC.helpers.plot_helper as plt_helper 

8from pySDC.helpers.stats_helper import get_sorted 

9 

10from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

11from pySDC.implementations.problem_classes.AllenCahn_2D_FD import ( 

12 allencahn_fullyimplicit, 

13 allencahn_semiimplicit, 

14 allencahn_semiimplicit_v2, 

15 allencahn_multiimplicit, 

16 allencahn_multiimplicit_v2, 

17) 

18from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

19from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

20from pySDC.implementations.sweeper_classes.multi_implicit import multi_implicit 

21from pySDC.projects.TOMS.AllenCahn_monitor import monitor 

22 

23# http://www.personal.psu.edu/qud2/Res/Pre/dz09sisc.pdf 

24 

25 

def setup_parameters():
    """
    Build the default (incomplete) parameter set shared by all SDC variants

    The same dictionaries are used for every variant, so they contain more entries than any
    single run needs. The problem and sweeper classes are left as ``None`` here and are filled
    in by the caller depending on the chosen variant.

    Returns:
        description (dict): step description (problem/sweeper/level/step parameters)
        controller_params (dict): parameters for the controller
    """

    # level parameters
    level_params = {
        'restol': 1e-08,
        'dt': 1e-03,
        'nsweeps': [1],
    }

    # sweeper parameters (Q1/Q2 for multi-implicit, QI/QE for implicit and IMEX sweepers)
    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'num_nodes': [3],
        'Q1': ['LU'],
        'Q2': ['LU'],
        'QI': ['LU'],
        'QE': ['EE'],
        'initial_guess': 'zero',
    }

    # parameters read in by the problem class
    problem_params = {
        'nu': 2,
        'nvars': [(128, 128)],
        'eps': [0.04],
        'newton_maxiter': 100,
        'newton_tol': 1e-09,
        'lin_tol': 1e-10,
        'lin_maxiter': 100,
        'radius': 0.25,
    }

    # step parameters
    step_params = {'maxiter': 50}

    # controller parameters
    controller_params = {
        'logger_level': 30,
        'hook_class': monitor,
    }

    # description dictionary for step instantiation; classes are set per variant by the caller
    description = {
        'problem_class': None,
        'problem_params': problem_params,
        'sweeper_class': None,
        'sweeper_params': sweeper_params,
        'level_params': level_params,
        'step_params': step_params,
    }

    return description, controller_params

83 

84 

def run_SDC_variant(variant=None, inexact=False):
    """
    Routine to run a particular SDC variant

    Args:
        variant (str): string describing the variant, one of 'fully-implicit', 'semi-implicit',
            'semi-implicit_v2', 'multi-implicit', 'multi-implicit_v2'
        inexact (bool): flag to use an inexact solve (or not). Depending on the variant this caps
            the Newton iterations at 1 and/or the linear solver iterations at 10.

    Returns:
        stats: the statistics returned by ``controller.run`` (consumed via ``get_sorted``)

    Raises:
        NotImplementedError: if ``variant`` is not one of the supported strings
    """

    # load (incomplete) default parameters; a fresh copy is created on every call
    description, controller_params = setup_parameters()
    problem_params = description['problem_params']

    # add problem/sweeper classes and inexactness settings based on the variant
    if variant == 'fully-implicit':
        description['problem_class'] = allencahn_fullyimplicit
        description['sweeper_class'] = generic_implicit
        if inexact:
            problem_params['newton_maxiter'] = 1
    elif variant == 'semi-implicit':
        description['problem_class'] = allencahn_semiimplicit
        description['sweeper_class'] = imex_1st_order
        if inexact:
            problem_params['lin_maxiter'] = 10
    elif variant == 'semi-implicit_v2':
        description['problem_class'] = allencahn_semiimplicit_v2
        description['sweeper_class'] = imex_1st_order
        if inexact:
            problem_params['newton_maxiter'] = 1
    elif variant == 'multi-implicit':
        description['problem_class'] = allencahn_multiimplicit
        description['sweeper_class'] = multi_implicit
        if inexact:
            problem_params['newton_maxiter'] = 1
            problem_params['lin_maxiter'] = 10
    elif variant == 'multi-implicit_v2':
        description['problem_class'] = allencahn_multiimplicit_v2
        description['sweeper_class'] = multi_implicit
        if inexact:
            problem_params['newton_maxiter'] = 1
    else:
        # wrap in a tuple so that any value (even a tuple) formats safely
        raise NotImplementedError('Wrong variant specified, got %s' % (variant,))

    if inexact:
        out = 'Working on inexact %s variant...' % variant
    else:
        out = 'Working on exact %s variant...' % variant
    print(out)

    # setup parameters "in time"
    t0 = 0
    Tend = 0.032

    # instantiate controller
    controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

    # get initial values on finest level
    P = controller.MS[0].levels[0].prob
    uinit = P.u_exact(t0)

    # call main function to get things done...
    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    # filter statistics by variant (number of iterations)
    iter_counts = get_sorted(stats, type='niter', sortby='time')

    # compute and print statistics
    niters = np.array([entry[1] for entry in iter_counts])
    out = ' Mean number of iterations: %4.2f' % np.mean(niters)
    print(out)
    out = ' Range of values for number of iterations: %2i ' % np.ptp(niters)
    print(out)
    out = ' Position of max/min number of iterations: %2i -- %2i' % (int(np.argmax(niters)), int(np.argmin(niters)))
    print(out)
    out = ' Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters)))
    print(out)

    # solver statistics accumulated by the problem class; max(..., 1) avoids division by zero
    print(' Iteration count (nonlinear/linear): %i / %i' % (P.newton_itercount, P.lin_itercount))
    print(
        ' Mean Iteration count per call: %4.2f / %4.2f'
        % (P.newton_itercount / max(P.newton_ncalls, 1), P.lin_itercount / max(P.lin_ncalls, 1))
    )

    timing = get_sorted(stats, type='timing_run', sortby='time')

    print('Time to solution: %6.4f sec.' % timing[0][1])
    print()

    return stats

177 

178 

def show_results(fname, cwd=''):
    """
    Plotting routine

    Reads the pickled results dictionary (keys are ``(variant, 'exact'|'inexact')`` tuples as
    written by ``main``) and produces three plots: timings/iterations, radii and interface width.
    Each plot is saved as ``fname + '_<suffix>'``, and the routine asserts that the PDF and PNG
    files exist.

    Args:
        fname (str): file name to read in and name plots
        cwd (str): current working directory (prefix for reading the pickle file only)
    """

    # the pickle is produced by this project's own main(); never load untrusted files here
    with open(cwd + fname + '.pkl', 'rb') as file:
        results = dill.load(file)

    plt_helper.setup_mpl()

    _plot_timings(results, fname)
    _plot_radii(results, fname)
    _plot_interface_width(results, fname)

    return None


def _assert_plot_files(f):
    """Assert that saving plot ``f`` produced both a PDF and a PNG file (PGF is not checked)."""
    assert os.path.isfile(f + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(f + '.png'), 'ERROR: plotting did not create PNG file'


def _plot_timings(results, fname):
    """Bar plot of run time (left axis) and mean iteration count (right axis), slowest variant first."""
    _, ax1 = plt_helper.newfig(textwidth=238.96, scale=1.5, ratio=0.4)

    timings = {}
    niters = {}
    for key, stats in results.items():
        timings[key] = get_sorted(stats, type='timing_run', sortby='time')[0][1]
        iter_counts = get_sorted(stats, type='niter', sortby='time')
        niters[key] = np.mean(np.array([entry[1] for entry in iter_counts]))

    # variants ordered by decreasing run time (stable for ties)
    ordered_keys = sorted(timings, key=lambda k: timings[k], reverse=True)
    xcoords = list(range(len(ordered_keys)))
    heights_timings = [timings[k] for k in ordered_keys]
    heights_niters = [niters[k] for k in ordered_keys]
    # e.g. ('semi-implicit_v2', 'exact') -> 'exact semi\nimplicit mod.'
    tick_labels = [(k[1] + ' ' + k[0]).replace('-', '\n').replace('_v2', ' mod.') for k in ordered_keys]

    # timings to the left of each tick, iterations to the right (negative/positive edge widths)
    ax1.bar(xcoords, heights_timings, align='edge', width=-0.3, label='timings (left axis)')
    ax1.set_ylabel('time (sec)')

    ax2 = ax1.twinx()
    ax2.bar(xcoords, heights_niters, color='lightcoral', align='edge', width=0.3, label='iterations (right axis)')
    ax2.set_ylabel('mean number of iterations')

    ax1.set_xticks(xcoords)
    ax1.set_xticklabels(tick_labels, rotation=90, ha='center')

    # combine the legend entries of both axes into a single legend
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc=0)

    # save plot, beautify
    f = fname + '_timings'
    plt_helper.savefig(f)
    _assert_plot_files(f)


def _plot_radii(results, fname):
    """Plot computed vs. exact radius over time and check every variant stays close to the exact radius."""
    _, ax = plt_helper.newfig(textwidth=238.96, scale=1.0)

    exact_radii = []
    for key, stats in results.items():
        computed_radii = get_sorted(stats, type='computed_radius', sortby='time')

        times = [entry[0] for entry in computed_radii]
        radii = [entry[1] for entry in computed_radii]
        label = key[0] + ' ' + key[1]
        # only the exact fully-implicit run is drawn, but every variant is checked below
        if label == 'fully-implicit exact':
            ax.plot(times, radii, label=label.replace('_v2', ' mod.'))

        exact_radii = get_sorted(stats, type='exact_radius', sortby='time')

        diff = np.array(
            [abs(exact[1] - computed[1]) for exact, computed in zip(exact_radii, computed_radii, strict=True)]
        )
        max_pos = int(np.argmax(diff))
        max_diff = diff.max()
        assert max_diff < 0.07, 'ERROR: computed radius is too far away from exact radius, got %s' % max_diff
        assert 0.028 < computed_radii[max_pos][0] < 0.03, (
            'ERROR: largest difference is at wrong time, got %s' % computed_radii[max_pos][0]
        )

    # reference curve from the exact radius of the last variant processed
    times = [entry[0] for entry in exact_radii]
    radii = [entry[1] for entry in exact_radii]
    ax.plot(times, radii, color='k', linestyle='--', linewidth=1, label='exact')

    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.2f'))
    ax.set_ylabel('radius')
    ax.set_xlabel('time')
    ax.grid()
    ax.legend(loc=3)

    # save plot, beautify
    f = fname + '_radii'
    plt_helper.savefig(f)
    _assert_plot_files(f)


def _plot_interface_width(results, fname):
    """Plot the interface width over time against its initial value held constant."""
    _, ax = plt_helper.newfig(textwidth=238.96, scale=1.0)

    interface_width = []
    for key, stats in results.items():
        interface_width = get_sorted(stats, type='interface_width', sortby='time')
        times = [entry[0] for entry in interface_width]
        width = [entry[1] for entry in interface_width]
        label = key[0] + ' ' + key[1]
        if label == 'fully-implicit exact':
            ax.plot(times, width, label=label)

    # reference: the initial width (of the last variant processed), held constant over time
    times = [entry[0] for entry in interface_width]
    init_width = [interface_width[0][1]] * len(times)
    ax.plot(times, init_width, color='k', linestyle='--', linewidth=1, label='exact')

    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%1.2f'))
    ax.set_ylabel(r'interface width ($\epsilon$)')
    ax.set_xlabel('time')
    ax.grid()
    ax.legend(loc=3)

    # save plot, beautify
    f = fname + '_interface'
    plt_helper.savefig(f)
    _assert_plot_files(f)

304 

305 

def main(cwd=''):
    """
    Main driver

    Runs every SDC variant with exact and inexact solves, pickles the collected statistics
    to ``cwd + 'data/results_SDC_variants_AllenCahn_1E-03.pkl'`` and plots the results.

    Args:
        cwd (str): current working directory (need this for testing)
    """

    # Loop over variants, exact and inexact solves
    results = {}
    for variant in ['multi-implicit', 'semi-implicit', 'fully-implicit', 'semi-implicit_v2', 'multi-implicit_v2']:
        results[(variant, 'exact')] = run_SDC_variant(variant=variant, inexact=False)
        results[(variant, 'inexact')] = run_SDC_variant(variant=variant, inexact=True)

    # dump result; make sure the output directory exists first
    fname = 'data/results_SDC_variants_AllenCahn_1E-03'
    path = cwd + fname + '.pkl'
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, 'wb') as file:
        dill.dump(results, file)
    assert os.path.isfile(path), 'ERROR: dill did not create file'

    # visualize
    show_results(fname, cwd=cwd)

329 

330 

if __name__ == "__main__":
    # script entry point: run all variants, store the statistics and create the plots
    main()