Coverage for pySDC/projects/FastWaveSlowWave/plot_stab_vs_k.py: 95%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1import matplotlib 

2 

3matplotlib.use('Agg') 

4 

5import numpy as np 

6from matplotlib import pyplot as plt 

7from pylab import rcParams 

8from matplotlib.ticker import ScalarFormatter 

9 

10from pySDC.implementations.problem_classes.FastWaveSlowWave_0D import swfw_scalar 

11from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

12 

13 

14from pySDC.core.step import Step 

15 

16 

# noinspection PyShadowingNames
def compute_stab_vs_k(slow_resolved, *, mvals=None, kvals=None, lambda_fast=10j):
    """
    Routine to compute the modulus of the stability function versus the number of sweeps

    For every number of collocation nodes M in ``mvals`` and every number of sweeps K in ``kvals``,
    the stability function of the scalar fast-wave/slow-wave test problem is evaluated for the
    eigenvalues ``lambda_fast`` and ``lambda_slow`` and its modulus is stored.

    Args:
        slow_resolved (bool): switch to compute with lambda_slow = 1j (True) or lambda_slow = 4j (False)
        mvals (list, optional): numbers of collocation nodes, defaults to [2, 3, 4]
        kvals (numpy.ndarray, optional): numbers of sweeps, defaults to numpy.arange(1, 10)
        lambda_fast (complex, optional): fast eigenvalue, defaults to 10j

    Returns:
        list: number of nodes
        numpy.ndarray: number of iterations
        numpy.ndarray: moduli, with shape (len(mvals), len(kvals))
    """
    if mvals is None:
        mvals = [2, 3, 4]
    if kvals is None:
        kvals = np.arange(1, 10)

    # PLOT EITHER FOR lambda_slow = 1j (resolved) OR lambda_slow = 4j (unresolved)
    lambda_slow = 1j if slow_resolved else 4j

    stabval = np.zeros((np.size(mvals), np.size(kvals)))

    # The problem's own lambdas are set to zero; the eigenvalues used for the stability analysis
    # are passed explicitly to get_scalar_problems_manysweep_mat below.
    problem_params = {
        'lambda_s': np.array([0.0]),
        'lambda_f': np.array([0.0]),
        'u0': 1.0,
    }

    # sweeper parameters shared by all M; num_nodes is added per step below
    sweeper_params = {
        'quad_type': 'RADAU-RIGHT',
        'do_coll_update': True,
    }

    level_params = {'dt': 1.0}

    # fill description dictionary for easy step instantiation
    description = {
        'problem_class': swfw_scalar,  # pass problem class
        'problem_params': problem_params,  # pass problem parameters
        'sweeper_class': imex_1st_order,  # pass sweeper class
        'level_params': level_params,  # pass level parameters
        'step_params': dict(),  # pass step parameters
    }

    for i, num_nodes in enumerate(mvals):
        # fresh dict per step so no two steps share (and see mutations of) the same sweeper params
        description['sweeper_params'] = {**sweeper_params, 'num_nodes': num_nodes}

        # now the description contains more or less everything we need to create a step
        S = Step(description=description)
        L = S.levels[0]

        nnodes = L.sweep.coll.num_nodes
        ones = np.ones(nnodes)

        for k, Kmax in enumerate(kvals):
            Mat_sweep = L.sweep.get_scalar_problems_manysweep_mat(nsweeps=Kmax, lambdas=[lambda_fast, lambda_slow])
            if L.sweep.params.do_coll_update:
                # collocation update: 1 + (lambda_fast + lambda_slow) * weights . (Mat_sweep . 1), level dt is 1.0
                stab_fh = 1.0 + (lambda_fast + lambda_slow) * L.sweep.coll.weights.dot(Mat_sweep.dot(ones))
            else:
                # without the collocation update, take the value at the last node
                # (RADAU-RIGHT nodes include the right end of the interval)
                q = np.zeros(nnodes)
                q[nnodes - 1] = 1.0
                stab_fh = q.dot(Mat_sweep.dot(ones))
            stabval[i, k] = np.absolute(stab_fh)

    return mvals, kvals, stabval

89 

90 

# noinspection PyShadowingNames
def plot_stab_vs_k(slow_resolved, mvals, kvals, stabval):
    """
    Plotting routine for moduli of the stability function versus the number of iterations

    One curve is drawn per entry of ``mvals`` (one row of ``stabval`` each), plus a dashed
    reference line at modulus 1. The figure is saved to data/stab_vs_k_resolved.png or
    data/stab_vs_k_unresolved.png (the data directory is created if missing) and then closed.

    Args:
        slow_resolved (bool): switch for lambda_slow; selects legend position and output file name
        mvals (list): number of nodes, one entry per row of stabval
        kvals (numpy.ndarray): number of iterations
        stabval (numpy.ndarray): moduli, with shape (len(mvals), len(kvals))
    """
    import os

    # (line format, color) per curve; the first three reproduce the original styling
    styles = [('o-', 'b'), ('s-', 'r'), ('d-', 'g'), ('^-', 'm'), ('v-', 'c'), ('x-', 'y')]

    rcParams['figure.figsize'] = 2.5, 2.5
    fig = plt.figure()
    fs = 8
    for i, m in enumerate(mvals):
        fmt, color = styles[i % len(styles)]
        plt.plot(kvals, stabval[i, :], fmt, color=color, label=("M=%2i" % m), markersize=fs - 2)
    # reference line |R| = 1 (stability boundary); np.ones also works for list-valued kvals
    plt.plot(kvals, np.ones(np.size(kvals)), '--', color='k')
    plt.xlabel('Number of iterations K', fontsize=fs)
    plt.ylabel(r'Modulus of stability function $\left| R \right|$', fontsize=fs)
    plt.ylim([0.0, 1.2])

    # legend font size is given via prop (fontsize= would be ignored when prop is set)
    legend_loc = 'upper right' if slow_resolved else 'lower left'
    plt.legend(loc=legend_loc, prop={'size': fs})

    plt.gca().get_xaxis().set_major_formatter(ScalarFormatter())

    if slow_resolved:
        filename = 'data/stab_vs_k_resolved.png'
    else:
        filename = 'data/stab_vs_k_unresolved.png'

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    fig.savefig(filename, bbox_inches='tight')
    # release the figure; otherwise every call leaves an open pyplot figure behind
    plt.close(fig)

127 

128 

if __name__ == "__main__":
    # Resolved slow mode first, then the unresolved one: compute the moduli,
    # report the largest modulus, and save the corresponding plot.
    for resolved in (True, False):
        mvals, kvals, stabval = compute_stab_vs_k(slow_resolved=resolved)
        print(np.amax(stabval))
        plot_stab_vs_k(resolved, mvals, kvals, stabval)