Coverage for pySDC/projects/parallelSDC/newton_vs_sdc.py: 100%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 09:13 +0000

1import os 

2import pickle 

3 

4import pySDC.helpers.plot_helper as plt_helper 

5from pySDC.helpers.stats_helper import get_sorted 

6 

7from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI 

8from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit 

9from pySDC.projects.parallelSDC.ErrReductionHook import err_reduction_hook 

10from pySDC.projects.parallelSDC.GeneralizedFisher_1D_FD_implicit_Jac import generalized_fisher_jac 

11from pySDC.projects.parallelSDC.linearized_implicit_fixed_parallel import linearized_implicit_fixed_parallel 

12from pySDC.projects.parallelSDC.linearized_implicit_fixed_parallel_prec import linearized_implicit_fixed_parallel_prec 

13 

14 

def main():
    """
    Run the error-reduction study for the generalized Fisher problem with SDC and Newton-type sweepers

    For every sweeper in ``sweeper_list`` and every step size in ``dt_list`` the problem is integrated from
    ``t0`` to ``Tend``. The ratio between the error statistics ``error_post_iteration`` and
    ``error_pre_iteration`` is computed for the lowest recorded iteration and stored per sweeper. The
    statistics are gathered with the ``err_reduction_hook`` hook class.

    The collected results are pickled to ``data/error_reduction_data.pkl`` for use by ``plot_graphs``. That
    file is a dict with the keys ``'sweeper_list'`` (sweeper class names), ``'dt_list'`` and one list of
    reduction ratios per sweeper name.
    """
    # initialize level parameters
    level_params = dict()
    level_params['restol'] = 1e-12

    # This comes as read-in for the step class (this is optional!)
    step_params = dict()
    step_params['maxiter'] = 20

    # This comes as read-in for the problem class
    problem_params = dict()
    problem_params['nu'] = 1
    problem_params['nvars'] = 2047
    problem_params['lambda0'] = 5.0
    problem_params['newton_maxiter'] = 50
    problem_params['newton_tol'] = 1e-12
    problem_params['interval'] = (-5, 5)

    # This comes as read-in for the sweeper class
    sweeper_params = dict()
    sweeper_params['quad_type'] = 'RADAU-RIGHT'
    sweeper_params['num_nodes'] = 5
    sweeper_params['QI'] = 'LU'
    sweeper_params['fixed_time_in_jacobian'] = 0

    # initialize controller parameters
    controller_params = dict()
    controller_params['logger_level'] = 30
    controller_params['hook_class'] = err_reduction_hook

    # Fill description dictionary for easy hierarchy creation
    description = dict()
    description['problem_class'] = generalized_fisher_jac
    description['problem_params'] = problem_params
    description['sweeper_params'] = sweeper_params
    description['step_params'] = step_params

    # setup parameters "in time"
    t0 = 0
    Tend = 0.1

    sweeper_list = [generic_implicit, linearized_implicit_fixed_parallel, linearized_implicit_fixed_parallel_prec]
    dt_list = [Tend / 2**i for i in range(1, 5)]

    results = dict()
    results['sweeper_list'] = [sweeper.__name__ for sweeper in sweeper_list]
    results['dt_list'] = dt_list

    # loop over the different sweepers and check results
    for sweeper in sweeper_list:
        description['sweeper_class'] = sweeper
        error_reduction = []
        for dt in dt_list:
            print('Working with sweeper %s and dt = %s...' % (sweeper.__name__, dt))

            # level_params is shared across runs; only the step size changes
            level_params['dt'] = dt
            description['level_params'] = level_params

            # instantiate the controller
            controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

            # get initial values on finest level
            P = controller.MS[0].levels[0].prob
            uinit = P.u_exact(t0)

            # call main function to get things done... (the final solution itself is not needed here)
            _, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)

            # filter statistics: take the entry of the lowest iteration number after sorting by 'iter'
            error_pre = get_sorted(stats, type='error_pre_iteration', sortby='iter')[0][1]
            error_post = get_sorted(stats, type='error_post_iteration', sortby='iter')[0][1]

            error_reduction.append(error_post / error_pre)

            print('error and reduction rate at time %s: %6.4e -- %6.4e' % (Tend, error_post, error_reduction[-1]))

        results[sweeper.__name__] = error_reduction
        print()

    # Make sure the output directory exists so the (expensive) results are not lost at the very end,
    # and close the file even if pickling fails.
    os.makedirs('data', exist_ok=True)
    with open('data/error_reduction_data.pkl', 'wb') as f:
        pickle.dump(results, f)

98 

99 

def plot_graphs(cwd=''):
    """
    Helper function to plot the error reduction of the different sweepers against the step size

    Reads the results pickled by ``main`` from ``cwd + 'data/error_reduction_data.pkl'``. It plots the error
    reduction of each sweeper on log-log axes together with linear and quadratic reference lines. The figure
    is saved as ``data/parallelSDC_fisher_newton``; note that this output path is NOT prefixed with ``cwd``.

    Args:
        cwd (str): path prefix (including a trailing separator) prepended to the input data path

    Raises:
        AssertionError: if the PDF or PNG output file was not created
    """
    plt_helper.mpl.style.use('classic')

    # NOTE: pickle.load executes arbitrary code from untrusted files; this file is expected to be the one
    # produced by main().
    with open(cwd + 'data/error_reduction_data.pkl', 'rb') as f:
        results = pickle.load(f)

    sweeper_list = results['sweeper_list']
    dt_list = results['dt_list']

    color_list = ['red', 'blue', 'green']
    marker_list = ['o', 's', 'd']
    # Unknown sweepers fall back to their class name so that labels stay aligned with sweepers, colors and
    # markers in the zip below (skipping them would shift every following label onto the wrong curve).
    label_map = {
        'generic_implicit': 'SDC',
        'linearized_implicit_fixed_parallel': 'Simplified Newton',
        'linearized_implicit_fixed_parallel_prec': 'Inexact Newton',
    }
    label_list = [label_map.get(sweeper, sweeper) for sweeper in sweeper_list]

    # zip truncates to the shortest list, so at most len(color_list) sweepers are plotted
    setups = zip(sweeper_list, color_list, marker_list, label_list)

    plt_helper.setup_mpl()

    plt_helper.newfig(textwidth=238.96, scale=0.89)

    for sweeper, color, marker, label in setups:
        plt_helper.plt.loglog(
            dt_list, results[sweeper], lw=1, ls='-', color=color, marker=marker, markeredgecolor='k', label=label
        )

    # reference slopes, scaled to sit near the measured curves
    plt_helper.plt.loglog(dt_list, [dt * 2 for dt in dt_list], lw=0.5, ls='--', color='k', label='linear')
    plt_helper.plt.loglog(
        dt_list, [dt * dt / dt_list[0] * 2 for dt in dt_list], lw=0.5, ls='-.', color='k', label='quadratic'
    )

    plt_helper.plt.xlabel('dt')
    plt_helper.plt.ylabel('error reduction')
    plt_helper.plt.grid()

    plt_helper.plt.xticks(dt_list, dt_list)

    plt_helper.plt.legend(loc=1, ncol=1)

    # largest dt on the left, decreasing to the right
    plt_helper.plt.gca().invert_xaxis()
    plt_helper.plt.xlim([dt_list[0] * 1.1, dt_list[-1] / 1.1])
    plt_helper.plt.ylim([4e-03, 1e0])

    # save plot, beautify
    fname = 'data/parallelSDC_fisher_newton'
    plt_helper.savefig(fname)

    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    # assert os.path.isfile(fname + '.pgf'), 'ERROR: plotting did not create PGF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'

162 

163 

if __name__ == "__main__":
    # By default only the figure is regenerated from previously pickled data;
    # call main() first to (re)compute data/error_reduction_data.pkl.
    plot_graphs()