Coverage for pySDC/projects/FastWaveSlowWave/plot_stab_vs_k.py: 95%
62 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
1import matplotlib
3matplotlib.use('Agg')
5import numpy as np
6from matplotlib import pyplot as plt
7from pylab import rcParams
8from matplotlib.ticker import ScalarFormatter
10from pySDC.implementations.problem_classes.FastWaveSlowWave_0D import swfw_scalar
11from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
14from pySDC.core.step import Step
# noinspection PyShadowingNames
def compute_stab_vs_k(slow_resolved, *, mvals=None, kvals=None, lambda_fast=10j):
    """
    Routine to compute modulus of the stability function

    For every number of collocation nodes M in ``mvals`` and every number of
    sweeps K in ``kvals``, a step with an ``imex_1st_order`` sweeper on
    RADAU-RIGHT nodes (dt = 1) is built, the K-sweep matrix for the scalar
    problem with eigenvalues ``[lambda_fast, lambda_slow]`` is obtained from
    the sweeper, and the modulus of the resulting stability function is stored.

    Args:
        slow_resolved (bool): switch to compute lambda_slow = 1j (True) or lambda_slow = 4j (False)
        mvals (list of int, optional): numbers of collocation nodes, defaults to [2, 3, 4]
        kvals (numpy.ndarray, optional): numbers of sweeps, defaults to numpy.arange(1, 10)
        lambda_fast (complex, optional): fast eigenvalue, defaults to 10j

    Returns:
        list: number of nodes (``mvals``)
        numpy.ndarray: number of iterations (``kvals``)
        numpy.ndarray: moduli, shape (len(mvals), len(kvals))
    """
    if mvals is None:
        mvals = [2, 3, 4]
    if kvals is None:
        kvals = np.arange(1, 10)

    # PLOT EITHER FOR lambda_slow = 1j (resolved) OR lambda_slow = 4j (unresolved)
    lambda_slow = 1j if slow_resolved else 4j

    stabval = np.zeros((np.size(mvals), np.size(kvals)))

    # The problem's own eigenvalues are set to zero here; the eigenvalues that
    # matter are passed explicitly to get_scalar_problems_manysweep_mat below
    # (presumably the problem values are unused for this analysis - TODO confirm).
    problem_params = {
        'lambda_s': np.array([0.0]),
        'lambda_f': np.array([0.0]),
        'u0': 1.0,
    }

    # sweeper parameters shared by all M; num_nodes is added per step below
    base_sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'do_coll_update': True,
    }

    # fill description dictionary for easy step instantiation
    description = {
        'problem_class': swfw_scalar,  # pass problem class
        'problem_params': problem_params,  # pass problem parameters
        'sweeper_class': imex_1st_order,  # pass sweeper
        'level_params': {'dt': 1.0},  # pass level parameters (dt = 1 is assumed in the update below)
        'step_params': {},  # pass step parameters
    }

    lambda_total = lambda_fast + lambda_slow

    for i, num_nodes in enumerate(mvals):
        # Use a fresh dict per step so no previously created step ever sees a mutated parameter set.
        description['sweeper_params'] = {**base_sweeper_params, 'num_nodes': num_nodes}

        # now the description contains more or less everything we need to create a step
        S = Step(description=description)
        L = S.levels[0]

        nnodes = L.sweep.coll.num_nodes
        ones = np.ones(nnodes)

        for k, num_sweeps in enumerate(kvals):
            Mat_sweep = L.sweep.get_scalar_problems_manysweep_mat(nsweeps=num_sweeps, lambdas=[lambda_fast, lambda_slow])
            node_values = Mat_sweep.dot(ones)
            if L.sweep.params.do_coll_update:
                # collocation update with u0 = 1 and dt = 1: 1 + lambda * sum_j w_j u_j
                stab_fh = 1.0 + lambda_total * L.sweep.coll.weights.dot(node_values)
            else:
                # no update: the end value is the value at the last node
                # (for RADAU-RIGHT nodes this is the right end of the step)
                stab_fh = node_values[-1]
            stabval[i, k] = np.absolute(stab_fh)

    return mvals, kvals, stabval
# noinspection PyShadowingNames
def plot_stab_vs_k(slow_resolved, mvals, kvals, stabval):
    """
    Plotting routine for moduli

    Draws one curve |R| over K for every entry of ``mvals`` (row i of
    ``stabval`` belongs to ``mvals[i]``) plus a dashed reference line at
    |R| = 1. The figure is saved to ``data/stab_vs_k_resolved.png`` or
    ``data/stab_vs_k_unresolved.png`` and then closed.

    Args:
        slow_resolved (bool): switch for lambda_slow (selects legend position and output file)
        mvals (list of int): number of nodes
        kvals (numpy.ndarray): number of iterations
        stabval (numpy.ndarray): moduli, shape (len(mvals), len(kvals))
    """
    # (line style, color) per curve; cycled if more than three node counts are given
    styles = [('o-', 'b'), ('s-', 'r'), ('d-', 'g')]

    rcParams['figure.figsize'] = 2.5, 2.5
    fig = plt.figure()
    ax = fig.gca()
    fs = 8

    for i, m in enumerate(mvals):
        linestyle, color = styles[i % len(styles)]
        ax.plot(kvals, stabval[i, :], linestyle, color=color, label=("M=%2i" % m), markersize=fs - 2)

    # stability boundary |R| = 1; np.ones also works when kvals is a plain list
    ax.plot(kvals, np.ones(np.size(kvals)), '--', color='k')

    ax.set_xlabel('Number of iterations K', fontsize=fs)
    ax.set_ylabel(r'Modulus of stability function $\left| R \right|$', fontsize=fs)
    ax.set_ylim([0.0, 1.2])

    # 'prop' takes precedence over 'fontsize' in matplotlib's legend, so only prop is passed
    legend_loc = 'upper right' if slow_resolved else 'lower left'
    ax.legend(loc=legend_loc, prop={'size': fs})

    ax.get_xaxis().set_major_formatter(ScalarFormatter())

    filename = 'data/stab_vs_k_resolved.png' if slow_resolved else 'data/stab_vs_k_unresolved.png'
    fig.savefig(filename, bbox_inches='tight')
    # release the figure; otherwise pyplot keeps every created figure alive
    plt.close(fig)
if __name__ == "__main__":
    # Run the resolved case (lambda_slow = 1j) first, then the unresolved case
    # (lambda_slow = 4j): print the largest modulus and save the plot for each.
    for resolved in (True, False):
        mvals, kvals, stabval = compute_stab_vs_k(slow_resolved=resolved)
        print(np.amax(stabval))
        plot_stab_vs_k(resolved, mvals, kvals, stabval)