Coverage for pySDC/projects/Monodomain/run_scripts/run_MonodomainODE.py: 87%
205 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1from pathlib import Path
2import numpy as np
3from mpi4py import MPI
4import logging
5import os
7from pySDC.core.errors import ParameterError
9from pySDC.projects.Monodomain.problem_classes.MonodomainODE import MultiscaleMonodomainODE
10from pySDC.projects.Monodomain.hooks.HookClass_pde import pde_hook
11from pySDC.projects.Monodomain.hooks.HookClass_post_iter_info import post_iter_info_hook
13from pySDC.helpers.stats_helper import get_sorted
15from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
16from pySDC.implementations.controller_classes.controller_MPI import controller_MPI
18from pySDC.projects.Monodomain.sweeper_classes.exponential_runge_kutta.imexexp_1st_order import (
19 imexexp_1st_order as imexexp_1st_order_ExpRK,
20)
21from pySDC.projects.Monodomain.sweeper_classes.runge_kutta.imexexp_1st_order import imexexp_1st_order
23from pySDC.projects.Monodomain.transfer_classes.TransferVectorOfDCTVectors import TransferVectorOfDCTVectors
25from pySDC.projects.Monodomain.utils.data_management import database
def set_logger(controller_params):
    """Configure logging from the controller parameters.

    The root logger (via ``logging.basicConfig``) and the ``"hooks"`` logger both use
    the level stored under ``controller_params["logger_level"]``.
    """
    level = controller_params["logger_level"]
    logging.basicConfig(level=level)
    logging.getLogger("hooks").setLevel(level)
def get_controller(controller_params, description, time_comm, n_time_ranks, truly_time_parallel):
    """Build the pySDC controller.

    If ``truly_time_parallel`` is False, time parallelism is emulated by ``controller_nonMPI``
    with ``n_time_ranks`` processes. Otherwise a ``controller_MPI`` is created on ``time_comm``.
    """
    if not truly_time_parallel:
        return controller_nonMPI(
            num_procs=n_time_ranks, controller_params=controller_params, description=description
        )
    return controller_MPI(controller_params=controller_params, description=description, comm=time_comm)
def print_dofs_stats(time_rank, controller, P, uinit):
    """Log the number of degrees of freedom of ``uinit`` (only on time rank 0).

    The total is ``uinit.size`` and the per-mesh count is ``uinit.shape[1]``.
    ``P`` is accepted for interface compatibility but not used.
    """
    total, per_mesh = uinit.size, uinit.shape[1]
    if time_rank == 0:
        controller.logger.info(f"Total dofs: {total}, mesh dofs = {per_mesh}")
def get_P_data(controller, truly_time_parallel):
    """Return ``(t0, Tend, uinit, P)`` for the level-0 problem ``P`` of ``controller``.

    The MPI controller exposes its step as ``controller.S``. The non-MPI controller holds
    a list of steps in ``controller.MS``, and the first one is used.
    """
    step = controller.S if truly_time_parallel else controller.MS[0]
    P = step.levels[0].prob
    return P.t0, P.Tend, P.initial_value(), P
def get_comms(n_time_ranks, truly_time_parallel):
    """Return the time communicator and this process's rank in it.

    Parameters
    ----------
    n_time_ranks : int
        Expected number of time ranks.
    truly_time_parallel : bool
        If True, ``MPI.COMM_WORLD`` is the time communicator. Otherwise time
        parallelism is emulated and ``(None, 0)`` is returned.

    Returns
    -------
    tuple
        ``(time_comm, time_rank)``.

    Raises
    ------
    ValueError
        If running truly time-parallel and the MPI world size differs from ``n_time_ranks``.
    """
    if not truly_time_parallel:
        return None, 0

    time_comm = MPI.COMM_WORLD
    time_rank = time_comm.Get_rank()
    # Explicit check rather than ``assert``: assertions are stripped under ``python -O``,
    # which would silently let a mismatched MPI launch proceed.
    if time_comm.Get_size() != n_time_ranks:
        raise ValueError("Number of time ranks does not match the number of MPI ranks")
    return time_comm, time_rank
def get_base_transfer_params(finter):
    """Return the base-transfer parameters, with ``finter`` stored under the ``"finter"`` key."""
    return {"finter": finter}
def get_controller_params(problem_params, n_time_ranks):
    """Assemble the controller parameter dictionary.

    - With more than one time rank the ``"pfasst_burnin"`` predictor is used; otherwise none.
    - The PDE hook is attached only when ``n_time_ranks == 1``.
    - The post-iteration info hook is always attached.
    """
    return {
        "predict_type": "pfasst_burnin" if n_time_ranks > 1 else None,
        "log_to_file": False,
        # NOTE(review): no path separator is inserted between output_root and "controller";
        # harmless while log_to_file is False, but confirm the intended file location.
        "fname": problem_params["output_root"] + "controller",
        "logger_level": 20,
        "dump_setup": False,
        "hook_class": [post_iter_info_hook, pde_hook] if n_time_ranks == 1 else [post_iter_info_hook],
    }
def get_description(
    integrator, problem_params, sweeper_params, level_params, step_params, base_transfer_params, space_transfer_class
):
    """Build the pySDC description dictionary for the multiscale monodomain ODE.

    Parameters
    ----------
    integrator : str
        ``"IMEXEXP"`` selects implicit-explicit-exponential integrators in the
        preconditioner with standard SDC. ``"IMEXEXP_EXPRK"`` selects the same
        preconditioner with exponential SDC.
    problem_params, sweeper_params, level_params, step_params, base_transfer_params : dict
        Stored unchanged under the corresponding description keys.
    space_transfer_class : type
        Class used for transfers between spatial levels.

    Raises
    ------
    ParameterError
        If ``integrator`` is not one of the names above.
    """
    available_sweepers = (
        ("IMEXEXP", imexexp_1st_order),
        ("IMEXEXP_EXPRK", imexexp_1st_order_ExpRK),
    )
    for name, sweeper_class in available_sweepers:
        if integrator == name:
            break
    else:
        raise ParameterError("Unknown integrator.")

    return {
        "sweeper_class": sweeper_class,
        "problem_class": MultiscaleMonodomainODE,
        "problem_params": problem_params,
        "sweeper_params": sweeper_params,
        "level_params": level_params,
        "step_params": step_params,
        "base_transfer_params": base_transfer_params,
        "space_transfer_class": space_transfer_class,
    }
def get_step_params(maxiter):
    """Return the step parameters: the maximum number of iterations per step."""
    return {"maxiter": maxiter}
def get_level_params(dt, nsweeps, restol, n_time_ranks):
    """Return the level parameters.

    ``dt``, ``nsweeps`` and ``restol`` are stored as given. The residual type is fixed to
    ``"full_rel"``, and ``parallel`` is True whenever more than one time rank is used.
    """
    return {
        "restol": restol,
        "dt": dt,
        "nsweeps": nsweeps,
        "residual_type": "full_rel",
        "parallel": n_time_ranks > 1,
    }
def get_sweeper_params(num_nodes, skip_residual_computation):
    """Return the sweeper parameters.

    The sweeper uses spread initial guesses, RADAU-RIGHT nodes and implicit Euler (``"IE"``)
    as QI. If ``skip_residual_computation`` is truthy, residual computation is skipped for
    the iteration stages ``IT_FINE``, ``IT_COARSE``, ``IT_DOWN`` and ``IT_UP``.
    """
    optional = (
        {"skip_residual_computation": ("IT_FINE", "IT_COARSE", "IT_DOWN", "IT_UP")}
        if skip_residual_computation
        else {}
    )
    return {
        "initial_guess": "spread",
        "quad_type": "RADAU-RIGHT",
        "num_nodes": num_nodes,
        "QI": "IE",
        **optional,
    }
def get_space_tranfer_params():
    """Return the class stored as ``space_transfer_class`` in the description.

    The name keeps its original spelling ("tranfer") so existing callers keep working.
    """
    return TransferVectorOfDCTVectors
def get_problem_params(
    domain_name,
    refinements,
    ionic_model_name,
    read_init_val,
    init_time,
    enable_output,
    end_time,
    order,
    output_root,
    output_file_name,
    ref_sol,
):
    """Assemble the problem parameters and create the output root folder.

    The output root is ``<dir of this file>/../../../../data/<output_root>``. It is created
    (with parents) if missing and stored as a string under ``"output_root"``.
    """
    this_dir = os.path.dirname(os.path.realpath(__file__))
    # A hierarchy root/domain_name/ref_<refinements>/ionic_model_name is created inside this
    # folder; initial values are stored there as well.
    root_folder = this_dir + "/../../../../data/" + output_root

    params = {
        # order of the spatial discretization
        "order": order,
        # number of refinements with respect to a baseline
        "refinements": refinements,
        # cube_1D, cube_2D, cube_3D, cuboid_1D, cuboid_2D, cuboid_3D,
        # cuboid_1D_small, cuboid_2D_small, cuboid_3D_small
        "domain_name": domain_name,
        # HH (Hodgkin-Huxley), CRN (Courtemanche-Ramirez-Nattel), TTP (Ten Tusscher-Panfilov)
        # or TTP_SMOOTH (a smoothed version of Ten Tusscher-Panfilov)
        "ionic_model_name": ionic_model_name,
        # True: read the initial value from file; False: trigger an action potential with a stimulus
        "read_init_val": read_init_val,
        # The stimulus happens at t=0 and t=1000 and lasts 2 ms. If init_time > 2, nothing
        # happens until t=1000; if init_time > 1002, the stimulus never happens.
        "init_time": init_time,
        # name of the file containing the initial value
        "init_val_name": "init_val_DCT",
        # output can be visualized with visualization/show_monodomain_sol.py
        "enable_output": enable_output,
        # write only the transmembrane potential V, not the ionic model variables
        "output_V_only": True,
        "output_root": root_folder,
        "output_file_name": output_file_name,
        # reference solution file name
        "ref_sol": ref_sol,
        "end_time": end_time,
    }
    Path(root_folder).mkdir(parents=True, exist_ok=True)
    return params
def _allgather_flat(time_comm, time_rank, local_items):
    """Gather each rank's list on rank 0, flatten it there, and broadcast the flat list to all ranks.

    This is a collective operation: every rank of ``time_comm`` must call it in the same order.
    """
    gathered = time_comm.gather(local_items, root=0)
    if time_rank == 0:
        gathered = [item for sublist in gathered for item in sublist]
    return time_comm.bcast(gathered, root=0)


def _first_residual_per_step(residuals, dt):
    """Reduce ``(time, residual)`` pairs to one residual value per time step.

    The pairs are sorted (ties in time are ordered by residual value). An entry is kept only
    if its time exceeds the last kept time by more than ``dt / 2``. Returns ``[]`` if
    ``residuals`` is empty.
    """
    ordered = sorted(residuals)
    if not ordered:
        return []
    kept = [ordered[0][1]]
    last_time = ordered[0][0]
    for entry in ordered[1:]:
        if entry[0] > last_time + dt / 2.0:
            kept.append(entry[1])
            last_time = entry[0]
    return kept


def setup_and_run(
    integrator,
    num_nodes,
    skip_residual_computation,
    num_sweeps,
    max_iter,
    dt,
    restol,
    domain_name,
    refinements,
    order,
    ionic_model_name,
    read_init_val,
    init_time,
    enable_output,
    write_as_reference_solution,
    write_all_variables,
    output_root,
    output_file_name,
    ref_sol,
    end_time,
    truly_time_parallel,
    n_time_ranks,
    finter,
    write_database,
):
    """Set up the monodomain problem and pySDC controller, run it, and collect statistics.

    The arguments are forwarded to the ``get_*`` helpers of this module. ``num_nodes`` and
    ``refinements`` are per-level lists; the number of levels is the longer of the two.

    Returns
    -------
    tuple
        ``(error_L2, rel_error_L2, avg_niters, times, niters, residuals)``:

        - ``times`` and ``niters`` come from the sorted ``"niter"`` statistics.
        - ``residuals`` holds one ``"residual_post_iteration"`` value per time step, or is
          empty if no residuals were recorded.
    """
    # get time communicator
    time_comm, time_rank = get_comms(n_time_ranks, truly_time_parallel)

    # maximum number of iterations in ESDC/MLESDC/PFASST
    step_params = get_step_params(maxiter=max_iter)
    # number of collocation nodes in each level
    sweeper_params = get_sweeper_params(num_nodes=num_nodes, skip_residual_computation=skip_residual_computation)
    # step size, number of sweeps per iteration, and residual tolerance for the stopping criterion
    level_params = get_level_params(dt=dt, nsweeps=num_sweeps, restol=restol, n_time_ranks=n_time_ranks)

    # only the finest level (index 0) writes output
    n_levels = max(len(refinements), len(num_nodes))
    enable_output = [enable_output] + [False] * (n_levels - 1)
    problem_params = get_problem_params(
        domain_name=domain_name,
        refinements=refinements,
        ionic_model_name=ionic_model_name,
        read_init_val=read_init_val,
        init_time=init_time,
        enable_output=enable_output,
        end_time=end_time,
        order=order,
        output_root=output_root,
        output_file_name=output_file_name,
        ref_sol=ref_sol,
    )

    space_transfer_class = get_space_tranfer_params()

    # remaining params
    base_transfer_params = get_base_transfer_params(finter)
    controller_params = get_controller_params(problem_params, n_time_ranks)
    description = get_description(
        integrator,
        problem_params,
        sweeper_params,
        level_params,
        step_params,
        base_transfer_params,
        space_transfer_class,
    )
    set_logger(controller_params)
    controller = get_controller(controller_params, description, time_comm, n_time_ranks, truly_time_parallel)

    # PDE data and dofs stats
    t0, Tend, uinit, P = get_P_data(controller, truly_time_parallel)
    print_dofs_stats(time_rank, controller, P, uinit)

    # run
    uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

    # write reference solution, to be used later for error computation
    if write_as_reference_solution:
        P.write_reference_solution(uend, write_all_variables)

    # compute errors if a reference solution is found; the availability flag is not needed here
    _, error_L2, rel_error_L2 = P.compute_errors(uend)

    iter_counts = get_sorted(stats, type="niter", sortby="time")
    residuals = get_sorted(stats, type="residual_post_iteration", sortby="time")
    if time_comm is not None:
        # merge the per-rank statistics so that every rank sees all of them
        iter_counts = _allgather_flat(time_comm, time_rank, iter_counts)
        residuals = _allgather_flat(time_comm, time_rank, residuals)

    iter_counts.sort()
    times = [item[0] for item in iter_counts]
    niters = [item[1] for item in iter_counts]

    residuals = _first_residual_per_step(residuals, dt)

    avg_niters = np.mean(niters)
    if time_rank == 0:
        controller.logger.info("Mean number of iterations: %4.2f" % avg_niters)
        controller.logger.info(
            "Std and var for number of iterations: %4.2f -- %4.2f" % (float(np.std(niters)), float(np.var(niters)))
        )

    if write_database and time_rank == 0:
        errors = {"error_L2": error_L2, "rel_error_L2": rel_error_L2}
        iters_info = {"avg_niters": avg_niters, "times": times, "niters": niters, "residuals": residuals}
        file_name = P.output_folder / Path(P.output_file_name)
        db_file = file_name.with_suffix(".db")
        # remove a database file left over from a previous run before writing
        if db_file.is_file():
            os.remove(db_file)
        data_man = database(file_name)
        data_man.write_dictionary("errors", errors)
        data_man.write_dictionary("iters_info", iters_info)

    return error_L2, rel_error_L2, avg_niters, times, niters, residuals
def main():
    """Run one monodomain simulation with a hard-coded default configuration."""
    write_as_reference_solution = False
    settings = {
        # sweeper: "IMEXEXP" (IMEX-exponential preconditioner + standard SDC)
        # or "IMEXEXP_EXPRK" (IMEX-exponential preconditioner + exponential SDC)
        "integrator": "IMEXEXP_EXPRK",
        "num_nodes": [4],
        "num_sweeps": [1],
        "skip_residual_computation": False,
        # step parameters
        "max_iter": 100,
        # level parameters
        "dt": 0.05,
        "restol": 5e-8,
        # problem parameters
        "domain_name": "cube_2D",
        "refinements": [-1],
        "order": 4,  # 2 or 4
        "ionic_model_name": "TTP",
        "read_init_val": True,
        "init_time": 3.0,
        "enable_output": False,
        "write_as_reference_solution": write_as_reference_solution,
        "write_all_variables": False,
        "write_database": False,
        "end_time": 0.05,
        "output_root": "results_tmp",
        "output_file_name": "ref_sol" if write_as_reference_solution else "monodomain",
        "ref_sol": "ref_sol",
        "finter": False,
        # time parallelism: truly parallel with MPI (True) or emulated (False)
        "truly_time_parallel": False,
        "n_time_ranks": 1,
    }
    setup_and_run(**settings)
if __name__ == "__main__":
    # Run the default configuration from main() only when executed as a script, not on import.
    main()