Coverage for pySDC/projects/parallelSDC/nonlinear_playground.py: 100%

106 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1import os 

2import pickle 

3 

4import numpy as np 

5 

6import pySDC.helpers.plot_helper as plt_helper 

7from pySDC.helpers.stats_helper import get_sorted 

8 

9from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

10from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

11from pySDC.projects.parallelSDC.GeneralizedFisher_1D_FD_implicit_Jac import generalized_fisher_jac 

12from pySDC.projects.parallelSDC.linearized_implicit_fixed_parallel import linearized_implicit_fixed_parallel 

13from pySDC.projects.parallelSDC.linearized_implicit_fixed_parallel_prec import linearized_implicit_fixed_parallel_prec 

14from pySDC.projects.parallelSDC.linearized_implicit_parallel import linearized_implicit_parallel 

15 

16 

def main():
    """
    Run the generalized Fisher problem with several sweepers and record the results

    For each sweeper class in turn, a controller with a single process integrates from
    t0 = 0 to Tend = 0.1. The error against the problem's exact solution and statistics
    on the number of iterations are printed and written to
    ``data/parallelSDC_nonlinear_out.txt``. The initial, computed and exact values of the
    last sweeper in the list are pickled to ``data/parallelSDC_results_graphs.pkl``.

    Raises:
        AssertionError: if the error or the mean number of iterations of a sweeper is not
            as expected, or if the pickle file does not exist after writing
    """
    # initialize level parameters
    level_params = dict()
    level_params['restol'] = 1e-10
    level_params['dt'] = 0.01

    # This comes as read-in for the step class (this is optional!)
    step_params = dict()
    step_params['maxiter'] = 50

    # This comes as read-in for the problem class
    problem_params = dict()
    problem_params['nu'] = 1
    problem_params['nvars'] = 255
    problem_params['lambda0'] = 5.0
    problem_params['newton_maxiter'] = 50
    problem_params['newton_tol'] = 1e-12
    problem_params['interval'] = (-5, 5)

    # This comes as read-in for the sweeper class
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 5
    sweeper_params['QI'] = 'LU'
    sweeper_params['fixed_time_in_jacobian'] = 0

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30

    # Fill description dictionary for easy hierarchy creation
    description = dict()
    description['problem_class'] = generalized_fisher_jac
    description['problem_params'] = problem_params
    description['sweeper_params'] = sweeper_params
    description['level_params'] = level_params
    description['step_params'] = step_params

    sweeper_list = [
        generic_implicit,
        linearized_implicit_fixed_parallel_prec,
        linearized_implicit_fixed_parallel,
        linearized_implicit_parallel,
    ]

    # these are overwritten in every pass of the loop; the values of the last sweeper
    # are the ones that end up in the pickle file below
    uinit = None
    uex = None
    uend = None
    P = None

    # setup parameters "in time" (identical for all sweepers)
    t0 = 0
    Tend = 0.1

    # the context manager closes the log file even if one of the assertions below fails
    with open('data/parallelSDC_nonlinear_out.txt', 'w') as f:
        # loop over the different sweepers and check results
        for sweeper in sweeper_list:
            description['sweeper_class'] = sweeper

            # instantiate the controller
            controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

            # get initial values on finest level
            P = controller.MS[0].levels[0].prob
            uinit = P.u_exact(t0)

            # call main function to get things done...
            uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

            # compute exact solution and compare
            uex = P.u_exact(Tend)
            err = abs(uex - uend)

            print('error at time %s: %s' % (Tend, err))

            # filter statistics by type (number of iterations)
            iter_counts = get_sorted(stats, type='niter', sortby='time')

            # compute and print statistics
            niters = np.array([item[1] for item in iter_counts])
            out = ' Mean number of iterations: %4.2f' % np.mean(niters)
            f.write(out + '\n')
            print(out)
            out = ' Range of values for number of iterations: %2i ' % np.ptp(niters)
            f.write(out + '\n')
            print(out)
            out = ' Position of max/min number of iterations: %2i -- %2i' % (
                int(np.argmax(niters)),
                int(np.argmin(niters)),
            )
            f.write(out + '\n')
            print(out)
            out = ' Std and var for number of iterations: %4.2f -- %4.2f' % (float(np.std(niters)), float(np.var(niters)))
            f.write(out + '\n')
            # NOTE(review): this line is written to the file twice; it looks like a
            # copy-paste leftover but is kept so the output file stays unchanged - confirm
            f.write(out + '\n')
            print(out)

            f.write('\n')
            print()

            assert err < 3.686e-05, 'ERROR: error is too high for sweeper %s, got %s' % (sweeper.__name__, err)
            assert (
                np.mean(niters) == 7.5 or np.mean(niters) == 4.0
            ), 'ERROR: mean number of iterations not as expected, got %s' % np.mean(niters)

    results = dict()
    results['interval'] = problem_params['interval']
    results['xvalues'] = np.array([(i + 1 - (P.nvars + 1) / 2) * P.dx for i in range(P.nvars)])
    results['uinit'] = uinit
    results['uend'] = uend
    results['uex'] = uex

    # write out for later visualization; closing the file before the check below makes
    # sure the pickled data has actually been flushed to disk
    with open('data/parallelSDC_results_graphs.pkl', 'wb') as file:
        pickle.dump(results, file)

    assert os.path.isfile('data/parallelSDC_results_graphs.pkl'), 'ERROR: pickle did not create file'

136 

137 

def plot_graphs():
    """
    Helper function to plot graphs of initial and final values

    Loads the results dictionary from ``data/parallelSDC_results_graphs.pkl``, plots the
    initial, computed and exact values over the x-values and saves the figure under the
    base name ``data/parallelSDC_fisher``.

    Raises:
        AssertionError: if the PDF or the PNG file does not exist after saving
    """

    # NOTE: pickle.load must only be used on trusted files; this one is expected to be
    # the file written by this module. The context manager closes the file after loading.
    with open('data/parallelSDC_results_graphs.pkl', 'rb') as file:
        results = pickle.load(file)

    interval = results['interval']
    xvalues = results['xvalues']
    uinit = results['uinit']
    uend = results['uend']
    uex = results['uex']

    plt_helper.setup_mpl()

    # set up figure
    plt_helper.newfig(textwidth=338.0, scale=1.0)

    plt_helper.plt.xlabel('x')
    plt_helper.plt.ylabel('f(x)')
    # small margin so that values at the interval boundaries are not cut off
    plt_helper.plt.xlim((interval[0] - 0.01, interval[1] + 0.01))
    plt_helper.plt.ylim((-0.1, 1.1))
    plt_helper.plt.grid()

    # plot
    plt_helper.plt.plot(xvalues, uinit, 'r--', lw=1, label='initial')
    plt_helper.plt.plot(xvalues, uend, 'bs', lw=1, markeredgecolor='k', label='computed')
    plt_helper.plt.plot(xvalues, uex, 'g-', lw=1, label='exact')

    plt_helper.plt.legend(loc=2, ncol=1)

    # save plot as PDF, beautify
    fname = 'data/parallelSDC_fisher'
    plt_helper.savefig(fname)

    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    # assert os.path.isfile(fname + '.pgf'), 'ERROR: plotting did not create PGF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'

177 

178 

if __name__ == "__main__":
    # Script entry point: only the plotting step is executed. The computation is left
    # disabled here, so the results pickle read by plot_graphs() has to exist already.
    # main()
    plot_graphs()