Coverage for pySDC/projects/Monodomain/run_scripts/run_TestODE.py: 90%
168 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1from pathlib import Path
2import numpy as np
3import logging
4import os
6from tqdm import tqdm
8from pySDC.core.errors import ParameterError
10from pySDC.projects.Monodomain.problem_classes.TestODE import MultiscaleTestODE
11from pySDC.projects.Monodomain.transfer_classes.TransferVectorOfDCTVectors import TransferVectorOfDCTVectors
13from pySDC.projects.Monodomain.hooks.HookClass_post_iter_info import post_iter_info_hook
15from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
17from pySDC.projects.Monodomain.sweeper_classes.exponential_runge_kutta.imexexp_1st_order import (
18 imexexp_1st_order as imexexp_1st_order_ExpRK,
19)
20from pySDC.projects.Monodomain.sweeper_classes.runge_kutta.imexexp_1st_order import imexexp_1st_order
22"""
Run the multirate Dahlquist test equation and plot the stability domain of the method.
Only the exponential term and the stiff term are varied; the non-stiff term is kept constant (to allow 2D plots).
25"""
def set_logger(controller_params):
    """Configure logging from the controller parameters.

    Calls ``logging.basicConfig`` with the requested level (a no-op if the root logger is
    already configured) and sets the same level on the ``"hooks"`` logger.

    Args:
        controller_params (dict): must contain the key ``"logger_level"``.
    """
    level = controller_params["logger_level"]
    logging.basicConfig(level=level)
    logging.getLogger("hooks").setLevel(level)
def get_controller(controller_params, description, n_time_ranks):
    """Create a ``controller_nonMPI`` with ``num_procs=n_time_ranks``.

    Args:
        controller_params (dict): controller parameters.
        description (dict): pySDC description (problem, sweeper, level, step, transfer).
        n_time_ranks (int): value forwarded as ``num_procs``.

    Returns:
        controller_nonMPI: the newly constructed controller.
    """
    return controller_nonMPI(
        num_procs=n_time_ranks,
        controller_params=controller_params,
        description=description,
    )
def get_P_data(controller):
    """Extract time interval, initial value and problem from a controller.

    The problem is taken from level 0 of the first step (``controller.MS[0].levels[0].prob``).

    Returns:
        tuple: ``(t0, Tend, uinit, P)`` where ``uinit = P.initial_value()``.
    """
    prob = controller.MS[0].levels[0].prob
    return prob.t0, prob.Tend, prob.initial_value(), prob
def get_base_transfer_params():
    """Return a fresh dict of base-transfer parameters with ``finter`` disabled."""
    return {"finter": False}
def get_controller_params(output_root, logger_level):
    """Assemble the controller parameters shared by every stability run.

    Args:
        output_root (str): directory used to build the controller's ``fname``.
        logger_level (int): logging level handed to the controller.

    Returns:
        dict: parameters for ``controller_nonMPI`` (burn-in predictor, no file logging,
        no setup dump, ``post_iter_info_hook`` as hook).
    """
    return {
        "predict_type": "pfasst_burnin",
        "log_to_file": False,
        # Join with a separator: get_output_root() returns a path without a trailing slash,
        # so plain concatenation produced ".../results_tmpcontroller" outside the results
        # directory. (fname presumably only matters when log_to_file is enabled.)
        "fname": os.path.join(output_root, "controller"),
        "logger_level": logger_level,
        "dump_setup": False,
        "hook_class": [post_iter_info_hook],
    }
def get_description(
    integrator, problem_params, sweeper_params, level_params, step_params, base_transfer_params, space_transfer_class
):
    """Build the pySDC description dictionary for the multiscale test ODE.

    Args:
        integrator (str): ``"IMEXEXP"`` selects the Runge-Kutta ``imexexp_1st_order`` sweeper,
            ``"IMEXEXP_EXPRK"`` the exponential Runge-Kutta variant.
        problem_params, sweeper_params, level_params, step_params, base_transfer_params (dict):
            stored unchanged under the matching description keys.
        space_transfer_class: class stored under ``"space_transfer_class"``.

    Returns:
        dict: the description with ``MultiscaleTestODE`` as problem class.

    Raises:
        ParameterError: if ``integrator`` is not one of the two supported names.
    """
    known_sweepers = (
        ("IMEXEXP", imexexp_1st_order),
        ("IMEXEXP_EXPRK", imexexp_1st_order_ExpRK),
    )
    for name, sweeper_class in known_sweepers:
        if integrator == name:
            break
    else:
        raise ParameterError("Unknown integrator.")

    return {
        "sweeper_class": sweeper_class,
        "problem_class": MultiscaleTestODE,
        "problem_params": problem_params,
        "sweeper_params": sweeper_params,
        "level_params": level_params,
        "step_params": step_params,
        "base_transfer_params": base_transfer_params,
        "space_transfer_class": space_transfer_class,
    }
def get_step_params(maxiter):
    """Return step parameters holding the maximum number of iterations per step."""
    return {"maxiter": maxiter}
def get_level_params(dt, nsweeps, restol):
    """Return level parameters.

    Args:
        dt (float): step size.
        nsweeps: number of sweeps per iteration (per level when given as a list).
        restol (float): residual tolerance of the stopping criterion.

    Returns:
        dict: level parameters using the ``"full_rel"`` residual type.
    """
    return {
        "restol": restol,
        "dt": dt,
        "nsweeps": nsweeps,
        "residual_type": "full_rel",
    }
def get_sweeper_params(num_nodes):
    """Return sweeper parameters.

    Uses right Radau nodes, an implicit-Euler ``QI`` and a spread initial guess.

    Args:
        num_nodes: number of collocation nodes (per level when given as a list).
    """
    return {
        "initial_guess": "spread",
        "quad_type": "RADAU-RIGHT",
        "num_nodes": num_nodes,
        "QI": "IE",
    }
def get_output_root():
    """Return the results directory, relative to the directory of this script.

    The path is ``<script dir>/../../../../data/Monodomain/results_tmp`` (not normalized,
    no trailing slash).
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    return f"{script_dir}/../../../../data/Monodomain/results_tmp"
def get_problem_params(lmbda_laplacian, lmbda_gating, lmbda_others, end_time):
    """Return problem parameters for ``MultiscaleTestODE`` and create the output directory.

    Args:
        lmbda_laplacian (float): coefficient of the stiff term.
        lmbda_gating (float): coefficient of the exponential term.
        lmbda_others (float): coefficient of the non-stiff term.
        end_time (float): final time.

    Returns:
        dict: problem parameters; ``output_root`` is ``get_output_root()``, which is created
        on disk (with parents) if missing.
    """
    output_root = get_output_root()
    params = {
        "output_file_name": "monodomain",
        "output_root": output_root,
        "end_time": end_time,
        "lmbda_laplacian": lmbda_laplacian,
        "lmbda_gating": lmbda_gating,
        "lmbda_others": lmbda_others,
    }
    Path(output_root).mkdir(parents=True, exist_ok=True)
    return params
def plot_stability_domain(lmbda_laplacian_list, lmbda_gating_list, R, integrator, num_nodes, n_time_ranks):
    """Plot |R| on the (z_e, z_I) grid as log-scaled filled contours and save it as a PNG.

    Args:
        lmbda_laplacian_list (np.ndarray): values of the stiff coefficient (y axis, z_I).
        lmbda_gating_list (np.ndarray): values of the exponential coefficient (x axis, z_e).
        R (array-like): values indexed as ``R[j, i]`` with ``j`` over ``lmbda_laplacian_list``
            and ``i`` over ``lmbda_gating_list``; only ``abs(R)`` is plotted.
        integrator (str): ``"IMEXEXP"`` or ``"IMEXEXP_EXPRK"``; selects the title
            (any other value leaves the title empty).
        num_nodes (list): collocation nodes per level; more than one entry means multilevel.
        n_time_ranks (int): number of time ranks; more than one means parallel-in-time.

    The figure is saved as ``data/stability_domain_<integrator>.png`` relative to the current
    working directory; the ``data`` directory is created if missing.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm
    import pySDC.helpers.plot_helper as plt_helper

    plt_helper.setup_mpl()

    fig, ax = plt_helper.plt.subplots(
        figsize=plt_helper.figsize(textwidth=400, scale=1.0, ratio=0.78), layout='constrained'
    )

    fs_label = 14
    fs_ticks = 12
    fs_title = 16
    X, Y = np.meshgrid(lmbda_gating_list, lmbda_laplacian_list)
    R = np.abs(R)
    # 13 logarithmically spaced contour levels from 1e-6 up to 1
    CS = ax.contourf(X, Y, R, cmap=plt.cm.viridis, levels=np.logspace(-6, 0, 13), norm=LogNorm())
    # dashed lines along z_I = 0 and z_e = 0
    ax.plot(lmbda_gating_list, 0 * lmbda_gating_list, 'k--', linewidth=1.0)
    ax.plot(0 * lmbda_laplacian_list, lmbda_laplacian_list, 'k--', linewidth=1.0)
    ax.contour(CS, levels=CS.levels, colors='black')
    ax.set_xlabel(r'$z_{e}$', fontsize=fs_label, labelpad=-5)
    ax.set_ylabel(r'$z_{I}$', fontsize=fs_label, labelpad=-10)
    ax.tick_params(axis='x', labelsize=fs_ticks)
    ax.tick_params(axis='y', labelsize=fs_ticks)

    # Title prefix. Every (levels, ranks) combination must yield a prefix: a single level with
    # several time ranks previously matched no branch and raised UnboundLocalError. Any
    # multi-rank run is labelled "PFASST " (adjust if a different name is preferred).
    if n_time_ranks > 1:
        prefix = "PFASST "
    elif len(num_nodes) > 1:
        prefix = "ML"  # concatenates to "MLSDC" / "MLESDC"
    else:
        prefix = ""
    if integrator == "IMEXEXP":
        ax.set_title(prefix + "SDC stability domain", fontsize=fs_title)
    elif integrator == "IMEXEXP_EXPRK":
        ax.set_title(prefix + "ESDC stability domain", fontsize=fs_title)

    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")
    cbar = fig.colorbar(CS)
    cbar.ax.set_ylabel(r'$|R(z_e,z_{I})|$', fontsize=fs_label, labelpad=-20)
    cbar.set_ticks([cbar.vmin, cbar.vmax])  # keep only the ticks at the ends
    cbar.ax.tick_params(labelsize=fs_ticks)

    # The output path is relative to the CWD; create it so a long run is not lost at save time.
    Path("data").mkdir(parents=True, exist_ok=True)
    plt_helper.savefig("data/stability_domain_" + integrator, save_pdf=False, save_pgf=False, save_png=True)
    # release the figure; main() may be called several times in one process
    plt.close(fig)
def main(integrator, dl, l_min, openmp, n_time_ranks, end_time, num_nodes, check_stability, n_threads=12):
    """Sample the method's response on a grid of test-equation coefficients and plot it.

    For every grid point (lmbda_gating, lmbda_laplacian) in [l_min, 0] x [l_min, 0] with spacing
    ``dl``, the multiscale test ODE is solved (lmbda_others fixed at -1.0). ``abs(uend)`` of
    the final solution is stored in ``R[j, i]`` and then plotted.

    Args:
        integrator (str): ``"IMEXEXP"`` or ``"IMEXEXP_EXPRK"``.
        dl (float): grid spacing in both directions.
        l_min (float): lower bound of the grid in both directions (upper bound is 0.0).
        openmp (bool): if True, the outer loop over lmbda_gating is run in parallel with pymp.
        n_time_ranks (int): number of time ranks given to the non-MPI controller.
        end_time (float): final time passed to the problem.
        num_nodes (list[int]): collocation nodes per level.
        check_stability (bool): if True, require ``max |R| <= 1.0``.
        n_threads (int): number of pymp workers used when ``openmp`` is True
            (default 12, the previously hard-coded value).

    Raises:
        AssertionError: if ``check_stability`` is True and ``max |R|`` exceeds 1.0
            (or is NaN).
    """
    # time integration parameters
    step_params = get_step_params(maxiter=5)  # max iterations of SDC/ESDC/MLSDC/etc
    sweeper_params = get_sweeper_params(num_nodes=num_nodes)  # collocation nodes per level
    # step size, sweeps per iteration, residual tolerance of the stopping criterion
    level_params = get_level_params(dt=1.0, nsweeps=[1], restol=5e-8)
    space_transfer_class = TransferVectorOfDCTVectors
    base_transfer_params = get_base_transfer_params()
    controller_params = get_controller_params(get_output_root(), logger_level=40)

    # stability test parameters
    lmbda_others = -1.0  # the non-stiff term, kept constant
    lmbda_laplacian_min = l_min  # the stiff term
    lmbda_laplacian_max = 0.0
    lmbda_gating_min = l_min  # the exponential term
    lmbda_gating_max = 0.0

    # grid for the stability domain
    n_lmbda_laplacian = np.round((lmbda_laplacian_max - lmbda_laplacian_min) / dl).astype(int) + 1
    n_lmbda_gating = np.round((lmbda_gating_max - lmbda_gating_min) / dl).astype(int) + 1
    lmbda_laplacian_list = np.linspace(lmbda_laplacian_min, lmbda_laplacian_max, n_lmbda_laplacian)
    lmbda_gating_list = np.linspace(lmbda_gating_min, lmbda_gating_max, n_lmbda_gating)

    def _evaluate(i, j):
        # Solve the test ODE at (lmbda_gating_list[i], lmbda_laplacian_list[j]); return abs(uend).
        problem_params = get_problem_params(
            lmbda_laplacian=lmbda_laplacian_list[j],
            lmbda_gating=lmbda_gating_list[i],
            lmbda_others=lmbda_others,
            end_time=end_time,
        )
        description = get_description(
            integrator,
            problem_params,
            sweeper_params,
            level_params,
            step_params,
            base_transfer_params,
            space_transfer_class,
        )
        set_logger(controller_params)
        controller = get_controller(controller_params, description, n_time_ranks)
        t0, Tend, uinit, _ = get_P_data(controller)
        uend, _ = controller.run(u0=uinit, t0=t0, Tend=Tend)
        return abs(uend)

    if not openmp:
        R = np.zeros((n_lmbda_laplacian, n_lmbda_gating))
        for i in tqdm(range(n_lmbda_gating)):
            for j in range(n_lmbda_laplacian):
                R[j, i] = _evaluate(i, j)
    else:
        import pymp

        # shared array so results written by the pymp workers are visible here
        R = pymp.shared.array((n_lmbda_laplacian, n_lmbda_gating), dtype=float)
        with pymp.Parallel(n_threads) as p:
            for i in tqdm(p.range(0, n_lmbda_gating)):
                for j in range(n_lmbda_laplacian):
                    R[j, i] = _evaluate(i, j)

    plot_stability_domain(lmbda_laplacian_list, lmbda_gating_list, R, integrator, num_nodes, n_time_ranks)

    if check_stability:
        # explicit raise instead of `assert` so the check still runs under `python -O`;
        # `not (x <= 1.0)` also rejects NaN, as the original assert did
        max_abs_R = np.max(np.abs(R.ravel()))
        if not (max_abs_R <= 1.0):
            raise AssertionError("The maximum absolute value of the stability function is greater than 1.0.")
if __name__ == "__main__":
    # 1) exponential SDC with the implicit-explicit-exponential preconditioner:
    #    check that the stability function is bounded by 1.0
    # 2) standard SDC with the implicit-explicit-exponential preconditioner:
    #    no check, since this method is already known not to be stable
    for _integrator, _check in (("IMEXEXP_EXPRK", True), ("IMEXEXP", False)):
        main(
            integrator=_integrator,
            dl=2,
            l_min=-100,
            openmp=True,
            n_time_ranks=1,
            end_time=1.0,
            num_nodes=[5, 3],  # a fresh list for every call, as in separate literal calls
            check_stability=_check,
        )