Coverage for pySDC/projects/Monodomain/run_scripts/run_MonodomainODE.py: 87%

205 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1from pathlib import Path 

2import numpy as np 

3from mpi4py import MPI 

4import logging 

5import os 

6 

7from pySDC.core.errors import ParameterError 

8 

9from pySDC.projects.Monodomain.problem_classes.MonodomainODE import MultiscaleMonodomainODE 

10from pySDC.projects.Monodomain.hooks.HookClass_pde import pde_hook 

11from pySDC.projects.Monodomain.hooks.HookClass_post_iter_info import post_iter_info_hook 

12 

13from pySDC.helpers.stats_helper import get_sorted 

14 

15from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

16from pySDC.implementations.controller_classes.controller_MPI import controller_MPI 

17 

18from pySDC.projects.Monodomain.sweeper_classes.exponential_runge_kutta.imexexp_1st_order import ( 

19 imexexp_1st_order as imexexp_1st_order_ExpRK, 

20) 

21from pySDC.projects.Monodomain.sweeper_classes.runge_kutta.imexexp_1st_order import imexexp_1st_order 

22 

23from pySDC.projects.Monodomain.transfer_classes.TransferVectorOfDCTVectors import TransferVectorOfDCTVectors 

24 

25from pySDC.projects.Monodomain.utils.data_management import database 

26 

27 

def set_logger(controller_params):
    """Apply the controller's logging level to the root and ``"hooks"`` loggers.

    Args:
        controller_params (dict): controller settings; only the
            ``"logger_level"`` entry is read.
    """
    level = controller_params["logger_level"]
    logging.basicConfig(level=level)
    logging.getLogger("hooks").setLevel(level)

32 

33 

def get_controller(controller_params, description, time_comm, n_time_ranks, truly_time_parallel):
    """Build the controller that drives the time integration.

    Args:
        controller_params (dict): forwarded to the controller constructor.
        description (dict): forwarded to the controller constructor.
        time_comm: communicator handed to ``controller_MPI`` (unused otherwise).
        n_time_ranks (int): passed as ``num_procs`` to ``controller_nonMPI``.
        truly_time_parallel (bool): selects ``controller_MPI`` when truthy,
            ``controller_nonMPI`` otherwise.

    Returns:
        The constructed controller instance.
    """
    if not truly_time_parallel:
        return controller_nonMPI(
            num_procs=n_time_ranks,
            controller_params=controller_params,
            description=description,
        )
    return controller_MPI(controller_params=controller_params, description=description, comm=time_comm)

42 

43 

def print_dofs_stats(time_rank, controller, P, uinit):
    """Log the number of degrees of freedom of the initial value on rank 0.

    Args:
        time_rank (int): rank of the caller; only rank 0 emits the message.
        controller: object exposing a ``logger`` attribute.
        P: not used by this function; kept in the signature for callers.
        uinit: object with ``size`` and ``shape`` attributes; ``size`` is
            reported as the total and ``shape[1]`` as the mesh dofs.
    """
    # Both quantities are read before the rank check, on every rank.
    tot_dofs, mesh_dofs = uinit.size, uinit.shape[1]
    if time_rank != 0:
        return
    controller.logger.info(f"Total dofs: {tot_dofs}, mesh dofs = {mesh_dofs}")

49 

50 

def get_P_data(controller, truly_time_parallel):
    """Fetch the finest-level problem and its time interval and initial value.

    Args:
        controller: controller instance; the problem is read from
            ``controller.S`` when ``truly_time_parallel`` is truthy and from
            ``controller.MS[0]`` otherwise.
        truly_time_parallel (bool): selects which attribute holds the step.

    Returns:
        tuple: ``(t0, Tend, uinit, P)`` where ``P`` is ``levels[0].prob`` of
        the selected step, ``t0``/``Tend`` are its attributes and ``uinit`` is
        ``P.initial_value()``.
    """
    step = controller.S if truly_time_parallel else controller.MS[0]
    P = step.levels[0].prob
    # Tuple elements are evaluated left to right: t0, Tend, then the initial value.
    return P.t0, P.Tend, P.initial_value(), P

61 

62 

def get_comms(n_time_ranks, truly_time_parallel):
    """Return the time communicator and this process's rank in it.

    Args:
        n_time_ranks (int): expected number of ranks in time; checked against
            the size of ``MPI.COMM_WORLD`` in the truly parallel case.
        truly_time_parallel (bool): when falsy no MPI communicator is used.

    Returns:
        tuple: ``(time_comm, time_rank)``; ``(None, 0)`` when not truly
        time parallel, otherwise ``MPI.COMM_WORLD`` and its rank.
    """
    if not truly_time_parallel:
        return None, 0

    time_comm = MPI.COMM_WORLD
    time_rank = time_comm.Get_rank()
    assert time_comm.Get_size() == n_time_ranks, "Number of time ranks does not match the number of MPI ranks"
    return time_comm, time_rank

72 

73 

def get_base_transfer_params(finter):
    """Wrap ``finter`` in the dictionary used as ``base_transfer_params``.

    Args:
        finter: value stored under the ``"finter"`` key.

    Returns:
        dict: ``{"finter": finter}``.
    """
    return {"finter": finter}

78 

79 

def get_controller_params(problem_params, n_time_ranks):
    """Assemble the controller parameter dictionary.

    Args:
        problem_params (dict): only ``"output_root"`` is read; the string
            ``"controller"`` is appended to it to form ``"fname"``.
        n_time_ranks (int): number of ranks in time. With more than one rank
            the predictor is ``"pfasst_burnin"``; with exactly one rank the
            PDE hook is registered in addition to the post-iteration hook.

    Returns:
        dict: controller parameters.
    """
    return {
        "predict_type": "pfasst_burnin" if n_time_ranks > 1 else None,
        "log_to_file": False,
        "fname": problem_params["output_root"] + "controller",
        "logger_level": 20,  # numeric value of logging.INFO
        "dump_setup": False,
        "hook_class": [post_iter_info_hook, pde_hook] if n_time_ranks == 1 else [post_iter_info_hook],
    }

92 

93 

def get_description(
    integrator, problem_params, sweeper_params, level_params, step_params, base_transfer_params, space_transfer_class
):
    """Assemble the description dictionary handed to the controller.

    Args:
        integrator (str): ``"IMEXEXP"`` selects the implicit-explicit-exponential
            sweeper with standard SDC; ``"IMEXEXP_EXPRK"`` selects the variant
            with exponential SDC.
        problem_params (dict): stored under ``"problem_params"``.
        sweeper_params (dict): stored under ``"sweeper_params"``.
        level_params (dict): stored under ``"level_params"``.
        step_params (dict): stored under ``"step_params"``.
        base_transfer_params (dict): stored under ``"base_transfer_params"``.
        space_transfer_class: stored under ``"space_transfer_class"``.

    Returns:
        dict: the description, with ``MultiscaleMonodomainODE`` as problem class.

    Raises:
        ParameterError: if ``integrator`` is neither of the two known names.
    """
    problem = MultiscaleMonodomainODE

    if integrator == "IMEXEXP":
        # IMEXEXP integrators in the preconditioner, standard SDC
        sweeper = imexexp_1st_order
    elif integrator == "IMEXEXP_EXPRK":
        # IMEXEXP integrators in the preconditioner, exponential SDC
        sweeper = imexexp_1st_order_ExpRK
    else:
        raise ParameterError("Unknown integrator.")

    return {
        "sweeper_class": sweeper,
        "problem_class": problem,
        "problem_params": problem_params,
        "sweeper_params": sweeper_params,
        "level_params": level_params,
        "step_params": step_params,
        "base_transfer_params": base_transfer_params,
        "space_transfer_class": space_transfer_class,
    }

119 

120 

def get_step_params(maxiter):
    """Wrap the iteration limit in the step parameter dictionary.

    Args:
        maxiter: value stored under the ``"maxiter"`` key.

    Returns:
        dict: ``{"maxiter": maxiter}``.
    """
    return {"maxiter": maxiter}

125 

126 

def get_level_params(dt, nsweeps, restol, n_time_ranks):
    """Assemble the level parameter dictionary.

    Args:
        dt: step size, stored under ``"dt"``.
        nsweeps: number of sweeps, stored under ``"nsweeps"``.
        restol: residual tolerance, stored under ``"restol"``.
        n_time_ranks (int): number of ranks in time; ``"parallel"`` is set to
            ``n_time_ranks > 1``.

    Returns:
        dict: level parameters, with ``"residual_type"`` fixed to ``"full_rel"``.
    """
    return {
        "restol": restol,
        "dt": dt,
        "nsweeps": nsweeps,
        "residual_type": "full_rel",
        "parallel": n_time_ranks > 1,
    }

137 

138 

def get_sweeper_params(num_nodes, skip_residual_computation):
    """Assemble the sweeper parameter dictionary.

    Args:
        num_nodes: stored under ``"num_nodes"``.
        skip_residual_computation: when truthy, the key
            ``"skip_residual_computation"`` is added with the tuple
            ``("IT_FINE", "IT_COARSE", "IT_DOWN", "IT_UP")``; otherwise the
            key is absent.

    Returns:
        dict: sweeper parameters.
    """
    params = {
        "initial_guess": "spread",
        "quad_type": "RADAU-RIGHT",
        "num_nodes": num_nodes,
        "QI": "IE",
    }
    if skip_residual_computation:
        params["skip_residual_computation"] = ("IT_FINE", "IT_COARSE", "IT_DOWN", "IT_UP")
    return params

150 

151 

def get_space_tranfer_params():
    """Return the class used for transfers between spatial levels.

    Returns:
        type: ``TransferVectorOfDCTVectors`` (the class itself, not an instance).
    """
    return TransferVectorOfDCTVectors

157 

158 

def get_problem_params(
    domain_name,
    refinements,
    ionic_model_name,
    read_init_val,
    init_time,
    enable_output,
    end_time,
    order,
    output_root,
    output_file_name,
    ref_sol,
):
    """Assemble the problem parameter dictionary and create the output folder.

    Args:
        domain_name: name of the domain (cube_1D, cube_2D, cube_3D, cuboid_1D,
            cuboid_2D, cuboid_3D, cuboid_1D_small, cuboid_2D_small,
            cuboid_3D_small).
        refinements: number of refinements with respect to a baseline.
        ionic_model_name: HH, CRN, TTP or TTP_SMOOTH (Hodgkin-Huxley,
            Courtemanche-Ramirez-Nattel, Ten Tusscher-Panfilov and a smoothed
            Ten Tusscher-Panfilov).
        read_init_val: read the initial value from file (True) or initiate an
            action potential with a stimulus (False).
        init_time: initial time. The stimulus happens at t=0 and t=1000 and
            lasts 2ms; with init_time>2 nothing happens up to t=1000, with
            init_time>1002 nothing ever happens.
        enable_output: activate or deactivate output (viewable with
            visualization/show_monodomain_sol.py).
        end_time: stored under ``"end_time"``.
        order: order of the spatial discretization.
        output_root: folder name appended to the ``data`` directory located
            four levels above this file's directory.
        output_file_name: stored under ``"output_file_name"``.
        ref_sol: reference solution file name.

    Returns:
        dict: problem parameters. ``"output_root"`` holds the concatenated
        path string, and that directory is created if missing.
    """
    this_dir = os.path.dirname(os.path.realpath(__file__))
    # Output root folder. A hierarchy root/domain_name/ref_+str(refinements)/ionic_model_name
    # is created below it; initial values are put here as well.
    root_dir = this_dir + "/../../../../data/" + output_root

    problem_params = {
        "order": order,
        "refinements": refinements,
        "domain_name": domain_name,
        "ionic_model_name": ionic_model_name,
        "read_init_val": read_init_val,
        "init_time": init_time,
        "init_val_name": "init_val_DCT",  # name of the file containing the initial value
        "enable_output": enable_output,
        "output_V_only": True,  # output only the transmembrane potential (V), not the ionic model variables
        "output_root": root_dir,
        "output_file_name": output_file_name,
        "ref_sol": ref_sol,
        "end_time": end_time,
    }

    Path(root_dir).mkdir(parents=True, exist_ok=True)
    return problem_params

205 

206 

def _first_residual_per_step(residuals, dt):
    """Keep one residual value per time step from sorted ``(time, residual)`` pairs.

    An entry is kept when it is the first one, or when its time exceeds the
    time of the last kept entry by more than ``dt / 2``. Returns an empty list
    for empty input instead of raising ``IndexError``.
    """
    kept = []
    last_time = None
    for entry in residuals:
        if last_time is None or entry[0] > last_time + dt / 2.0:
            kept.append(entry[1])
            last_time = entry[0]
    return kept


def setup_and_run(
    integrator,
    num_nodes,
    skip_residual_computation,
    num_sweeps,
    max_iter,
    dt,
    restol,
    domain_name,
    refinements,
    order,
    ionic_model_name,
    read_init_val,
    init_time,
    enable_output,
    write_as_reference_solution,
    write_all_variables,
    output_root,
    output_file_name,
    ref_sol,
    end_time,
    truly_time_parallel,
    n_time_ranks,
    finter,
    write_database,
):
    """Configure the monodomain problem, run the controller and collect statistics.

    Args:
        integrator (str): ``"IMEXEXP"`` or ``"IMEXEXP_EXPRK"``.
        num_nodes (list): number of collocation nodes in each level.
        skip_residual_computation (bool): forwarded to the sweeper parameters.
        num_sweeps (list): number of sweeps per iteration.
        max_iter (int): maximum number of iterations.
        dt (float): step size; also used to thin the residual history.
        restol (float): residual tolerance for the stopping criterion.
        domain_name, refinements, order, ionic_model_name, read_init_val,
        init_time, output_root, output_file_name, ref_sol, end_time:
            forwarded to ``get_problem_params``.
        enable_output (bool): output switch; applied to the first level only,
            all further levels get ``False``.
        write_as_reference_solution (bool): when truthy the final solution is
            written through ``P.write_reference_solution``.
        write_all_variables (bool): forwarded to ``P.write_reference_solution``.
        truly_time_parallel (bool): use MPI in time (True) or emulate it (False).
        n_time_ranks (int): number of ranks in time.
        finter: forwarded to the base transfer parameters.
        write_database (bool): when truthy, rank 0 writes errors and iteration
            information to a database file in the problem's output folder.

    Returns:
        tuple: ``(error_L2, rel_error_L2, avg_niters, times, niters, residuals)``
        where ``residuals`` holds one value per time step.
    """
    # get time communicator
    time_comm, time_rank = get_comms(n_time_ranks, truly_time_parallel)

    # time integration parameters: iteration limit, collocation nodes,
    # step size / sweeps / residual tolerance
    step_params = get_step_params(maxiter=max_iter)
    sweeper_params = get_sweeper_params(num_nodes=num_nodes, skip_residual_computation=skip_residual_computation)
    level_params = get_level_params(
        dt=dt,
        nsweeps=num_sweeps,
        restol=restol,
        n_time_ranks=n_time_ranks,
    )

    # only the finest (first) level is allowed to produce output
    n_levels = max(len(refinements), len(num_nodes))
    enable_output = [enable_output] + [False] * (n_levels - 1)

    problem_params = get_problem_params(
        domain_name=domain_name,
        refinements=refinements,
        ionic_model_name=ionic_model_name,
        read_init_val=read_init_val,
        init_time=init_time,
        enable_output=enable_output,
        end_time=end_time,
        order=order,
        output_root=output_root,
        output_file_name=output_file_name,
        ref_sol=ref_sol,
    )

    space_transfer_class = get_space_tranfer_params()

    # remaining parameters
    base_transfer_params = get_base_transfer_params(finter)
    controller_params = get_controller_params(problem_params, n_time_ranks)
    description = get_description(
        integrator,
        problem_params,
        sweeper_params,
        level_params,
        step_params,
        base_transfer_params,
        space_transfer_class,
    )
    set_logger(controller_params)
    controller = get_controller(controller_params, description, time_comm, n_time_ranks, truly_time_parallel)

    # get PDE data
    t0, Tend, uinit, P = get_P_data(controller, truly_time_parallel)

    # print dofs stats
    print_dofs_stats(time_rank, controller, P, uinit)

    # run
    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    # write reference solution, to be used later for error computation
    if write_as_reference_solution:
        P.write_reference_solution(uend, write_all_variables)

    # compute errors, if a reference solution is found; the availability flag
    # returned first is not used here
    _, error_L2, rel_error_L2 = P.compute_errors(uend)

    # get some stats
    iter_counts = get_sorted(stats, type="niter", sortby="time")
    residuals = get_sorted(stats, type="residual_post_iteration", sortby="time")
    if time_comm is not None:
        # gather yields a list of per-rank lists on rank 0 (None elsewhere);
        # flatten there and broadcast so every rank holds the full history
        iter_counts = time_comm.gather(iter_counts, root=0)
        residuals = time_comm.gather(residuals, root=0)
        if time_rank == 0:
            iter_counts = [item for sublist in iter_counts for item in sublist]
            residuals = [item for sublist in residuals for item in sublist]
        iter_counts = time_comm.bcast(iter_counts, root=0)
        residuals = time_comm.bcast(residuals, root=0)

    iter_counts.sort()
    times = [item[0] for item in iter_counts]
    niters = [item[1] for item in iter_counts]

    # after sorting, the first entry of each step is the one that is kept
    residuals.sort()
    residuals = _first_residual_per_step(residuals, dt)

    avg_niters = np.mean(niters)
    if time_rank == 0:
        controller.logger.info("Mean number of iterations: %4.2f", avg_niters)
        controller.logger.info(
            "Std and var for number of iterations: %4.2f -- %4.2f", float(np.std(niters)), float(np.var(niters))
        )

    if write_database and time_rank == 0:
        errors = {"error_L2": error_L2, "rel_error_L2": rel_error_L2}
        iters_info = {
            "avg_niters": avg_niters,
            "times": times,
            "niters": niters,
            "residuals": residuals,
        }
        file_name = P.output_folder / Path(P.output_file_name)
        # start from a fresh database file
        db_file = file_name.with_suffix('.db')
        if db_file.is_file():
            os.remove(db_file)
        data_man = database(file_name)
        data_man.write_dictionary("errors", errors)
        data_man.write_dictionary("iters_info", iters_info)

    return error_L2, rel_error_L2, avg_niters, times, niters, residuals

350 

351 

def main():
    """Run a single default experiment through ``setup_and_run``.

    All settings are fixed here: one level with 4 collocation nodes, the
    ``"IMEXEXP_EXPRK"`` integrator (``"IMEXEXP"`` is the alternative), the
    ``"TTP"`` ionic model on ``"cube_2D"``, a single time step of size 0.05,
    and emulated (non-MPI) time parallelism with one rank.
    """
    write_as_reference_solution = False

    settings = {
        # sweeper parameters
        "integrator": "IMEXEXP_EXPRK",
        "num_nodes": [4],
        "skip_residual_computation": False,
        "num_sweeps": [1],
        # step parameters
        "max_iter": 100,
        # level parameters
        "dt": 0.05,
        "restol": 5e-8,
        # problem parameters
        "domain_name": "cube_2D",
        "refinements": [-1],
        "order": 4,  # 2 or 4
        "ionic_model_name": "TTP",
        "read_init_val": True,
        "init_time": 3.0,
        "enable_output": False,
        "write_as_reference_solution": write_as_reference_solution,
        "write_all_variables": False,
        "output_root": "results_tmp",
        "output_file_name": "ref_sol" if write_as_reference_solution else "monodomain",
        "ref_sol": "ref_sol",
        "end_time": 0.05,
        # time parallelism: True for MPI, False for emulated
        "truly_time_parallel": False,
        "n_time_ranks": 1,
        "finter": False,
        "write_database": False,
    }

    error_L2, rel_error_L2, avg_niters, times, niters, residuals = setup_and_run(**settings)

415 

416 

# Run the default experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()