Coverage for pySDC/projects/FastWaveSlowWave/plot_dispersion.py: 97%
115 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
import os

import matplotlib

matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import sympy
from pylab import rcParams

from pySDC.implementations.problem_classes.FastWaveSlowWave_0D import swfw_scalar
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order

from pySDC.implementations.problem_classes.acoustic_helpers.standard_integrators import dirk, rk_imex

from pySDC.core.step import Step
def findomega(stab_fh):
    """
    Compute the discrete frequency omega belonging to a 2x2 amplification matrix.

    The eigenvalues z of stab_fh are written as z = exp(-1j*omega), and omega is
    obtained by solving the characteristic equation
    det(exp(-1j*omega) * I - stab_fh) = 0 with sympy. Of the (at most two) roots
    considered, the first one with non-negative real part is returned.

    Args:
        stab_fh: 2x2 amplification matrix of a single time step (indexable as stab_fh[i, j])

    Returns:
        complex: the selected root omega. If no considered root has a non-negative
        real part, a message is printed and the first root is returned.

    Raises:
        ValueError: if stab_fh is not of shape (2, 2) or sympy finds no root
    """
    # explicit check instead of assert, so validation survives `python -O`
    if np.shape(stab_fh) != (2, 2):
        raise ValueError('Not 2x2 matrix...')

    omega = sympy.Symbol('omega')
    z = sympy.exp(-1j * omega)
    # characteristic polynomial (z - a)(z - d) - b*c of stab_fh, evaluated at z = exp(-1j*omega)
    func = (z - stab_fh[0, 0]) * (z - stab_fh[1, 1]) - stab_fh[0, 1] * stab_fh[1, 0]

    # At most the first two roots are considered. A repeated eigenvalue makes sympy
    # return a single root (and a zero eigenvalue may leave none), so do not index blindly.
    roots = [complex(sol) for sol in sympy.solve(func, omega)[:2]]
    if not roots:
        raise ValueError('No root found for the dispersion relation')

    for root in roots:
        if root.real >= 0:
            return root

    print("Two roots with real part of same sign...")
    return roots[0]
def _sdc_amplification_matrix(Q, QI, QE, weights, Cs, Uadv, dt, K):
    """
    Return the 2x2 amplification matrix of K IMEX-SDC sweeps followed by the collocation update.

    Cs is treated with the implicit matrix QI and Uadv with the explicit matrix QE.

    Args:
        Q: collocation matrix (without the first row/column)
        QI: implicit sweeper matrix (without the first row/column)
        QE: explicit sweeper matrix (without the first row/column)
        weights: collocation quadrature weights
        Cs: 2x2 matrix of the implicitly treated (acoustic) part
        Uadv: 2x2 matrix of the explicitly treated (advective) part
        dt: time step size
        K: number of sweeps

    Returns:
        numpy.ndarray: complex 2x2 amplification matrix
    """
    nnodes = Q.shape[0]
    LHS = np.eye(2 * nnodes) - dt * (np.kron(QI, Cs) + np.kron(QE, Uadv))
    RHS = dt * (np.kron(Q, Uadv + Cs) - np.kron(QI, Cs) - np.kron(QE, Uadv))

    LHSinv = np.linalg.inv(LHS)
    iter_mat = LHSinv.dot(RHS)

    # mat_sweep = iter_mat^K + sum_{j=0}^{K-1} iter_mat^j LHSinv, building one power at a time
    power = np.eye(2 * nnodes, dtype=complex)
    mat_sweep = np.zeros_like(LHSinv)
    for _ in range(K):
        mat_sweep += power.dot(LHSinv)
        power = power.dot(iter_mat)
    mat_sweep += power

    # NOTE: the original author flagged this update formula as still needing verification
    update = dt * np.kron(weights, Uadv + Cs)

    # column j: response to the initial value e_j, copied to every collocation node
    columns = [y + update.dot(mat_sweep.dot(np.kron(np.ones(nnodes), y))) for y in np.eye(2, dtype=complex)]
    return np.column_stack(columns)


def _plot_dispersion_curves(k_vec, exact, values, labels, ylabel, ylim, filename, fs=8):
    """
    Plot exact, DIRK, IMEX and SDC curves against the wave number, save them to filename and close the figure.

    Args:
        k_vec: sampled wave numbers
        exact: constant exact value drawn as a dashed reference line
        values: array with rows 0 = SDC, 1 = DIRK, 2 = RK-IMEX
        labels: legend labels in the same row order as values
        ylabel: label of the y axis
        ylim: [lower, upper] limits of the y axis
        filename: output path of the PNG file
        fs: base font size
    """
    fig = plt.figure()
    plt.plot(k_vec, np.full_like(k_vec, exact), '--', color='k', linewidth=1.5, label='Exact')
    plt.plot(k_vec, values[1, :], '-', color='g', linewidth=1.5, label=labels[1])
    plt.plot(
        k_vec,
        values[2, :],
        '-+',
        color='r',
        linewidth=1.5,
        label=labels[2],
        markevery=(2, 3),
        mew=1.0,
    )
    plt.plot(
        k_vec,
        values[0, :],
        '-o',
        color='b',
        linewidth=1.5,
        label=labels[0],
        markevery=(1, 3),
        markersize=fs / 2,
    )
    plt.xlabel('Wave number', fontsize=fs, labelpad=0.25)
    plt.ylabel(ylabel, fontsize=fs, labelpad=0.5)
    # scalar right limit (k_vec[-1]), not the 1-element array k_vec[-1:]
    plt.xlim([k_vec[0], k_vec[-1]])
    plt.ylim(ylim)
    fig.gca().tick_params(axis='both', labelsize=fs)
    # `prop` takes precedence over `fontsize` in legend(), so only `prop` is passed
    plt.legend(loc='lower left', prop={'size': fs - 2})
    plt.xticks([0, 1, 2, 3], fontsize=fs)
    fig.savefig(filename, bbox_inches='tight')
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


def compute_and_plot_dispersion(Nsamples=15, K=3):
    """
    Compute and plot the discrete dispersion relation of IMEX-SDC, DIRK and RK-IMEX.

    The system considered is y' = (Cs + Uadv) y with Cs = -1j*k*c_speed*[[0, 1], [1, 0]]
    (acoustic part) and Uadv = -1j*k*U_speed*I (advective part). For each sampled
    wave number k in (0, pi), the single-step amplification matrix of every
    integrator is formed and passed to findomega. From the resulting omega, the
    phase speed Re(omega)/k and the amplification factor exp(Im(omega)) are
    computed. Both are plotted against their exact values and saved as PNG files
    in the directory 'data', which is created if missing.

    Args:
        Nsamples: number of wave numbers to sample (must be at least 1)
        K: number of SDC iterations; also the order requested for the DIRK and RK-IMEX reference methods

    Raises:
        ValueError: if Nsamples < 1
    """
    if Nsamples < 1:
        raise ValueError('Nsamples must be at least 1')

    problem_params = dict()
    # SET VALUE FOR lambda_slow AND VALUES FOR lambda_fast
    problem_params['lambda_s'] = np.array([0.0])
    problem_params['lambda_f'] = np.array([0.0])
    problem_params['u0'] = 1.0

    # initialize sweeper parameters
    sweeper_params = dict()
    # SET TYPE AND NUMBER OF QUADRATURE NODES
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['do_coll_update'] = True
    sweeper_params['num_nodes'] = 3

    # initialize level parameters
    level_params = dict()
    level_params['dt'] = 1.0

    # fill description dictionary for easy step instantiation
    description = dict()
    description['problem_class'] = swfw_scalar
    description['problem_params'] = problem_params
    description['sweeper_class'] = imex_1st_order
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = dict()

    # ORDER OF DIRK/IMEX IS EQUAL TO NUMBER OF ITERATIONS AND THUS ORDER OF SDC
    dirk_order = K

    c_speed = 1.0
    U_speed = 0.05

    # the step is only instantiated to extract the SDC matrices of its first level
    S = Step(description=description)
    L = S.levels[0]
    QE = L.sweep.QE[1:, 1:]
    QI = L.sweep.QI[1:, 1:]
    Q = L.sweep.coll.Qmat[1:, 1:]
    weights = L.sweep.coll.weights
    dt = L.params.dt

    # Nsamples equidistant wave numbers in (0, pi); k = 0 is dropped because the phase speed divides by k
    k_vec = np.linspace(0, np.pi, Nsamples + 1, endpoint=False)[1:]
    # rows: 0 = SDC, 1 = DIRK, 2 = RK-IMEX
    phase = np.zeros((3, Nsamples))
    amp_factor = np.zeros((3, Nsamples))

    for i, k in enumerate(k_vec):
        Cs = -1j * k * np.array([[0.0, c_speed], [c_speed, 0.0]], dtype='complex')
        Uadv = -1j * k * np.array([[U_speed, 0.0], [0.0, U_speed]], dtype='complex')
        y1 = np.array([1, 0], dtype='complex')
        y2 = np.array([0, 1], dtype='complex')

        stab_sdc = _sdc_amplification_matrix(Q, QI, QE, weights, Cs, Uadv, dt, K)

        dirkts = dirk(Cs + Uadv, dirk_order)
        stab_dirk = np.column_stack((dirkts.timestep(y1, 1.0), dirkts.timestep(y2, 1.0)))

        rkimex = rk_imex(M_fast=Cs, M_slow=Uadv, order=K)
        stab_rk_imex = np.column_stack((rkimex.timestep(y1, 1.0), rkimex.timestep(y2, 1.0)))

        # solve for the discrete frequency and derive phase speed and amplification factor
        for row, stab in enumerate((stab_sdc, stab_dirk, stab_rk_imex)):
            omega = findomega(stab)
            phase[row, i] = omega.real / k
            amp_factor[row, i] = np.exp(omega.imag)

    labels = [
        'SDC(' + str(K) + ')',
        'DIRK(' + str(dirkts.order) + ')',
        'IMEX(' + str(rkimex.order) + ')',
    ]
    suffix = '-K' + str(K) + '-M' + str(sweeper_params['num_nodes']) + '.png'

    os.makedirs('data', exist_ok=True)
    rcParams['figure.figsize'] = 1.5, 1.5

    _plot_dispersion_curves(
        k_vec,
        U_speed + c_speed,
        phase,
        labels,
        'Phase speed',
        [0.0, 1.1 * (U_speed + c_speed)],
        'data/phase' + suffix,
    )
    _plot_dispersion_curves(
        k_vec,
        1.0,
        amp_factor,
        labels,
        'Amplification factor',
        [0.0, 1.1],
        'data/ampfactor' + suffix,
    )
if __name__ == '__main__':
    # Script entry point: produce both dispersion plots with the default settings.
    compute_and_plot_dispersion()