Coverage for pySDC/projects/Monodomain/run_scripts/run_TestODE.py: 90%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1from pathlib import Path 

2import numpy as np 

3import logging 

4import os 

5 

6from tqdm import tqdm 

7 

8from pySDC.core.errors import ParameterError 

9 

10from pySDC.projects.Monodomain.problem_classes.TestODE import MultiscaleTestODE 

11from pySDC.projects.Monodomain.transfer_classes.TransferVectorOfDCTVectors import TransferVectorOfDCTVectors 

12 

13from pySDC.projects.Monodomain.hooks.HookClass_post_iter_info import post_iter_info_hook 

14 

15from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

16 

17from pySDC.projects.Monodomain.sweeper_classes.exponential_runge_kutta.imexexp_1st_order import ( 

18 imexexp_1st_order as imexexp_1st_order_ExpRK, 

19) 

20from pySDC.projects.Monodomain.sweeper_classes.runge_kutta.imexexp_1st_order import imexexp_1st_order 

21 

22""" 

23Run the multirate Dahlquist test equation and plot the stability domain of the method. 

24We vary only the exponential term and the stiff term, while the non stiff term is kept constant (to allow 2D plots). 

25""" 

26 

27 

def set_logger(controller_params):
    """Configure logging at the level stored in ``controller_params["logger_level"]``.

    Calls ``logging.basicConfig`` with that level (which has no effect once the
    root logger already has handlers) and sets the ``"hooks"`` logger to the
    same level.
    """
    level = controller_params["logger_level"]
    logging.basicConfig(level=level)
    logging.getLogger("hooks").setLevel(level)

32 

33 

def get_controller(controller_params, description, n_time_ranks):
    """Build a ``controller_nonMPI`` running ``n_time_ranks`` time ranks (``num_procs``)."""
    return controller_nonMPI(
        num_procs=n_time_ranks,
        controller_params=controller_params,
        description=description,
    )

37 

38 

def get_P_data(controller):
    """Return ``(t0, Tend, initial value, problem)`` for the controller's problem.

    The problem object is taken from ``controller.MS[0].levels[0].prob``;
    ``initial_value()`` is evaluated once, after ``t0`` and ``Tend`` are read.
    """
    problem = controller.MS[0].levels[0].prob
    return problem.t0, problem.Tend, problem.initial_value(), problem

46 

47 

def get_base_transfer_params():
    """Return a fresh base-transfer parameter dict with ``finter`` disabled."""
    return {"finter": False}

52 

53 

def get_controller_params(output_root, logger_level):
    """Return the controller parameter dict used for every stability-grid run.

    Parameters
    ----------
    output_root : str
        Directory in which the controller log file (``fname``) would be placed;
        with or without a trailing separator.
    logger_level : int
        Logging level passed to the controller (and used by ``set_logger``).
    """
    return {
        "predict_type": "pfasst_burnin",
        "log_to_file": False,
        # os.path.join adds the separator that plain string concatenation omitted
        # when output_root has no trailing slash (as returned by get_output_root).
        "fname": os.path.join(output_root, "controller"),
        "logger_level": logger_level,
        "dump_setup": False,
        "hook_class": [post_iter_info_hook],
    }

63 

64 

def get_description(
    integrator, problem_params, sweeper_params, level_params, step_params, base_transfer_params, space_transfer_class
):
    """Assemble the pySDC description dict for the multiscale test ODE.

    ``integrator`` selects the sweeper: ``"IMEXEXP"`` uses the Runge-Kutta based
    ``imexexp_1st_order``, ``"IMEXEXP_EXPRK"`` the exponential-RK variant.
    Any other value raises ``ParameterError``.
    """
    if integrator == "IMEXEXP":
        sweeper = imexexp_1st_order
    elif integrator == "IMEXEXP_EXPRK":
        sweeper = imexexp_1st_order_ExpRK
    else:
        raise ParameterError("Unknown integrator.")

    return {
        "sweeper_class": sweeper,
        "problem_class": MultiscaleTestODE,
        "problem_params": problem_params,
        "sweeper_params": sweeper_params,
        "level_params": level_params,
        "step_params": step_params,
        "base_transfer_params": base_transfer_params,
        "space_transfer_class": space_transfer_class,
    }

87 

88 

def get_step_params(maxiter):
    """Return step parameters holding only the iteration cap ``maxiter``."""
    return dict(maxiter=maxiter)

93 

94 

def get_level_params(dt, nsweeps, restol):
    """Return level parameters: residual tolerance, step size, sweeps per iteration.

    The residual type is always ``"full_rel"``.
    """
    return {
        "restol": restol,
        "dt": dt,
        "nsweeps": nsweeps,
        "residual_type": "full_rel",
    }

103 

104 

def get_sweeper_params(num_nodes):
    """Return sweeper parameters for Radau-right nodes with an implicit-Euler ``QI``.

    ``num_nodes`` is stored unchanged (a list with one entry per level in this script).
    """
    return dict(
        initial_guess="spread",
        quad_type="RADAU-RIGHT",
        num_nodes=num_nodes,
        QI="IE",
    )

114 

115 

def get_output_root():
    """Return the results directory path, built relative to this file's real location."""
    here = os.path.dirname(os.path.realpath(__file__))
    return f"{here}/../../../../data/Monodomain/results_tmp"

120 

121 

def get_problem_params(lmbda_laplacian, lmbda_gating, lmbda_others, end_time):
    """Return the test-ODE problem parameters and ensure the output directory exists.

    The three ``lmbda_*`` values and ``end_time`` are stored as given; the output
    directory from ``get_output_root()`` is created (with parents) if missing.
    """
    output_root = get_output_root()
    problem_params = {
        "output_file_name": "monodomain",
        "output_root": output_root,
        "end_time": end_time,
        "lmbda_laplacian": lmbda_laplacian,
        "lmbda_gating": lmbda_gating,
        "lmbda_others": lmbda_others,
    }
    Path(output_root).mkdir(parents=True, exist_ok=True)
    return problem_params

133 

134 

def plot_stability_domain(lmbda_laplacian_list, lmbda_gating_list, R, integrator, num_nodes, n_time_ranks):
    """Plot |R| over the (z_e, z_I) grid as log-scale filled contours and save it as PNG.

    Parameters
    ----------
    lmbda_laplacian_list : array_like
        Stiff-term values z_I (y-axis; rows of ``R``).
    lmbda_gating_list : array_like
        Exponential-term values z_e (x-axis; columns of ``R``).
    R : array_like
        Stability-function values with shape
        ``(len(lmbda_laplacian_list), len(lmbda_gating_list))``, matching
        ``np.meshgrid(lmbda_gating_list, lmbda_laplacian_list)``.
    integrator : str
        ``"IMEXEXP"`` (title "SDC") or ``"IMEXEXP_EXPRK"`` (title "ESDC"); also
        part of the output file name ``data/stability_domain_<integrator>``.
    num_nodes : list
        Collocation nodes per level; more than one entry means multilevel.
    n_time_ranks : int
        Number of time ranks; more than one means PFASST.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm
    import pySDC.helpers.plot_helper as plt_helper

    plt_helper.setup_mpl()

    fig, ax = plt_helper.plt.subplots(
        figsize=plt_helper.figsize(textwidth=400, scale=1.0, ratio=0.78), layout='constrained'
    )

    fs_label = 14
    fs_ticks = 12
    fs_title = 16

    X, Y = np.meshgrid(lmbda_gating_list, lmbda_laplacian_list)
    R = np.abs(R)
    CS = ax.contourf(X, Y, R, cmap=plt.cm.viridis, levels=np.logspace(-6, 0, 13), norm=LogNorm())
    # Dashed lines along z_I = 0 and z_e = 0; zeros_like also works for plain lists.
    ax.plot(lmbda_gating_list, np.zeros_like(lmbda_gating_list), 'k--', linewidth=1.0)
    ax.plot(np.zeros_like(lmbda_laplacian_list), lmbda_laplacian_list, 'k--', linewidth=1.0)
    ax.contour(CS, levels=CS.levels, colors='black')
    ax.set_xlabel(r'$z_{e}$', fontsize=fs_label, labelpad=-5)
    ax.set_ylabel(r'$z_{I}$', fontsize=fs_label, labelpad=-10)
    ax.tick_params(axis='x', labelsize=fs_ticks)
    ax.tick_params(axis='y', labelsize=fs_ticks)

    # Method-name prefix. Every (levels, time ranks) combination gets a value, so
    # a single level with several time ranks no longer leaves `prefix` unbound.
    if n_time_ranks > 1:
        prefix = "PFASST "
    elif len(num_nodes) > 1:
        prefix = "ML"  # "ML" + "SDC" -> "MLSDC"
    else:
        prefix = ""

    if integrator == "IMEXEXP":
        ax.set_title(prefix + "SDC stability domain", fontsize=fs_title)
    elif integrator == "IMEXEXP_EXPRK":
        ax.set_title(prefix + "ESDC stability domain", fontsize=fs_title)
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")

    cbar = fig.colorbar(CS)
    cbar.ax.set_ylabel(r'$|R(z_e,z_{I})|$', fontsize=fs_label, labelpad=-20)
    cbar.set_ticks([cbar.vmin, cbar.vmax])  # keep only the ticks at the ends
    cbar.ax.tick_params(labelsize=fs_ticks)

    plt_helper.savefig("data/stability_domain_" + integrator, save_pdf=False, save_pgf=False, save_png=True)
    # Release the figure; main() may be called repeatedly in one process.
    plt_helper.plt.close(fig)

179 

180 

def _stability_function_value(
    integrator,
    lmbda_laplacian,
    lmbda_gating,
    lmbda_others,
    end_time,
    sweeper_params,
    level_params,
    step_params,
    base_transfer_params,
    space_transfer_class,
    controller_params,
    n_time_ranks,
):
    """Run the test ODE for one (stiff, exponential) eigenvalue pair and return ``abs(u(Tend))``."""
    problem_params = get_problem_params(
        lmbda_laplacian=lmbda_laplacian,
        lmbda_gating=lmbda_gating,
        lmbda_others=lmbda_others,
        end_time=end_time,
    )
    description = get_description(
        integrator,
        problem_params,
        sweeper_params,
        level_params,
        step_params,
        base_transfer_params,
        space_transfer_class,
    )
    # Applied before every controller is built so each run starts from the same logging setup.
    set_logger(controller_params)
    controller = get_controller(controller_params, description, n_time_ranks)

    t0, Tend, uinit, _ = get_P_data(controller)
    uend, _ = controller.run(u0=uinit, t0=t0, Tend=Tend)
    return abs(uend)


def main(integrator, dl, l_min, openmp, n_time_ranks, end_time, num_nodes, check_stability, num_threads=12):
    """Evaluate the stability function on a grid of (z_e, z_I) values and plot it.

    Both the stiff term z_I (``lmbda_laplacian``) and the exponential term z_e
    (``lmbda_gating``) range over ``[l_min, 0]`` with spacing about ``dl``; the
    non-stiff term is fixed at -1.0.

    Parameters
    ----------
    integrator : str
        ``"IMEXEXP"`` or ``"IMEXEXP_EXPRK"`` (see ``get_description``).
    dl : float
        Target grid spacing; must be positive.
    l_min : float
        Lower end of both eigenvalue ranges; must be non-positive.
    openmp : bool
        If True, distribute grid columns with ``pymp.Parallel(num_threads)``.
    n_time_ranks : int
        Number of time ranks passed to the controller.
    end_time : float
        End time stored in the problem parameters.
    num_nodes : list
        Collocation nodes per level.
    check_stability : bool
        If True, raise ``AssertionError`` when max |R| exceeds 1.0 (or is NaN).
    num_threads : int, optional
        Worker count for the pymp parallel region (default 12, as before).

    Raises
    ------
    ParameterError
        If ``dl`` is not positive or ``l_min`` is positive.
    """
    if dl <= 0:
        raise ParameterError("dl must be positive.")
    if l_min > 0:
        raise ParameterError("l_min must be non-positive.")

    # time integration parameters
    step_params = get_step_params(maxiter=5)  # max iterations of SDC/ESDC/MLSDC/etc.
    sweeper_params = get_sweeper_params(num_nodes=num_nodes)  # collocation nodes per level
    # step size, sweeps per iteration and residual tolerance of the stopping criterion
    level_params = get_level_params(dt=1.0, nsweeps=[1], restol=5e-8)
    space_transfer_class = TransferVectorOfDCTVectors
    base_transfer_params = get_base_transfer_params()
    controller_params = get_controller_params(get_output_root(), logger_level=40)

    # stability test parameters
    lmbda_others = -1.0  # the non-stiff term
    lmbda_laplacian_min = l_min  # the stiff term
    lmbda_laplacian_max = 0.0
    lmbda_gating_min = l_min  # the exponential term
    lmbda_gating_max = 0.0

    # grid for the stability domain
    n_lmbda_laplacian = np.round((lmbda_laplacian_max - lmbda_laplacian_min) / dl).astype(int) + 1
    n_lmbda_gating = np.round((lmbda_gating_max - lmbda_gating_min) / dl).astype(int) + 1
    lmbda_laplacian_list = np.linspace(lmbda_laplacian_min, lmbda_laplacian_max, n_lmbda_laplacian)
    lmbda_gating_list = np.linspace(lmbda_gating_min, lmbda_gating_max, n_lmbda_gating)

    def fill_column(R, i):
        # Column i of R: fixed exponential eigenvalue, every stiff eigenvalue.
        lmbda_gating = lmbda_gating_list[i]
        for j, lmbda_laplacian in enumerate(lmbda_laplacian_list):
            R[j, i] = _stability_function_value(
                integrator=integrator,
                lmbda_laplacian=lmbda_laplacian,
                lmbda_gating=lmbda_gating,
                lmbda_others=lmbda_others,
                end_time=end_time,
                sweeper_params=sweeper_params,
                level_params=level_params,
                step_params=step_params,
                base_transfer_params=base_transfer_params,
                space_transfer_class=space_transfer_class,
                controller_params=controller_params,
                n_time_ranks=n_time_ranks,
            )

    if not openmp:
        R = np.zeros((n_lmbda_laplacian, n_lmbda_gating))
        for i in tqdm(range(n_lmbda_gating)):
            fill_column(R, i)
    else:
        import pymp

        R = pymp.shared.array((n_lmbda_laplacian, n_lmbda_gating), dtype=float)
        with pymp.Parallel(num_threads) as p:
            for i in tqdm(p.range(0, n_lmbda_gating)):
                fill_column(R, i)

    plot_stability_domain(lmbda_laplacian_list, lmbda_gating_list, R, integrator, num_nodes, n_time_ranks)

    if check_stability:
        max_amplification = np.max(np.abs(R.ravel()))
        # Explicit raise (not an assert statement) so the check survives `python -O`;
        # `not (x <= 1.0)` also flags NaN from a diverged run as unstable.
        if not max_amplification <= 1.0:
            raise AssertionError("The maximum absolute value of the stability function is greater than 1.0.")

277 

278 

if __name__ == "__main__":
    # Both runs use the implicit-explicit-exponential integrator as preconditioner:
    # - exponential SDC ("IMEXEXP_EXPRK"): check that the stability function is bounded by 1.0;
    # - standard SDC ("IMEXEXP"): no check, the method is already known not to be stable here.
    for integrator, check_stability in (("IMEXEXP_EXPRK", True), ("IMEXEXP", False)):
        main(
            integrator=integrator,
            dl=2,
            l_min=-100,
            openmp=True,
            n_time_ranks=1,
            end_time=1.0,
            num_nodes=[5, 3],
            check_stability=check_stability,
        )