Coverage for pySDC/projects/TOMS/visualize_pySDC_with_PETSc.py: 99%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import os 

2 

3import matplotlib.colors as colors 

4import numpy as np 

5 

6import pySDC.helpers.plot_helper as plt_helper 

7 

8 

def is_number(s):
    """
    Helper function to detect numbers

    Accepts anything ``float()`` can parse (e.g. ``'12'``, ``'3.5'``, ``'1e3'``) as well as a
    single Unicode character that carries a numeric value (e.g. ``'½'``).

    Args:
        s: a string; inputs that neither ``float()`` nor ``unicodedata.numeric()`` accept
            (e.g. ``None``) yield False instead of raising

    Returns:
        bool: True if s is a number
    """
    import unicodedata

    try:
        float(s)
        return True
    except (TypeError, ValueError):
        # ValueError: non-numeric text; TypeError: unsupported input type such as None
        pass

    try:
        # only succeeds for a single character that has a Unicode numeric value
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        return False

34 

35 

def join_timings(file=None, result=None, cwd=''):
    """
    Helper function to read in JUBE result tables and convert/join them into a single dictionary

    Every line is stripped of spaces and split at ``|``. Lines whose first column is not a number
    (e.g. table headers or separators) are skipped. For all other lines:

    - column 2 is taken as the number of cores in space,
    - the product of columns 0 and 1 divided by column 2 is taken as the number of cores in time
      (presumably total cores = nodes x tasks per node; confirm against the JUBE configuration),
    - column 3 is taken as the runtime.

    Args:
        file: current file containing a JUBE result table
        result: dictionary to add the timings to; it is updated in place and also returned.
            If None, a new dictionary is created.
        cwd (str): current working directory, prepended verbatim to ``file``

    Returns:
        dict: result dictionary for further usage, mapping (cores in space, cores in time) to runtime
    """
    if result is None:
        result = {}

    with open(cwd + file) as f:
        for line in f:
            line_split = line.replace('\n', '').replace(' ', '').split('|')
            if not is_number(line_split[0]):
                continue
            nspace = int(line_split[2])
            # exact integer division avoids a round trip through float for the core count
            ntime = int(line_split[0]) * int(line_split[1]) // nspace
            timing = float(line_split[3])
            result[(nspace, ntime)] = timing

    return result

60 

61 

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    Build a new colormap from the sub-range [minval, maxval] of an existing one

    Args:
        cmap: colormap to crop; it is called on an array of sample positions and its ``name``
            is used to label the result
        minval: lower end of the retained range of the original colormap
        maxval: upper end of the retained range of the original colormap
        n: number of equally spaced samples taken from the original colormap

    Returns:
        matplotlib.colors.LinearSegmentedColormap: the cropped colormap,
        named ``trunc(<name>,<minval>,<maxval>)``
    """
    label = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(label, sampled_colors)

79 

80 

def visualize_matrix(result=None, process_list=None):
    """
    Visualizes runtimes in a matrix (cores in space vs. cores in time)

    The runtimes are drawn as an image on a logarithmic color scale spanning the smallest to the
    largest runtime, and each measured cell is annotated with its value. The figure is saved as
    ``data/runtimes_matrix_heat``, and the existence of the PDF and PNG files is asserted.

    Args:
        result: dictionary mapping (cores in space, cores in time) to a runtime
        process_list: core counts used for both axes, defaults to [1, 2, 4, 6, 12, 24];
            every core count appearing in the keys of ``result`` must be contained in it

    Raises:
        ValueError: if ``result`` is empty/None or contains a core count missing from ``process_list``
    """
    if process_list is None:
        process_list = [1, 2, 4, 6, 12, 24]
    if not result:
        raise ValueError('ERROR: no runtimes given, nothing to visualize')

    dim = len(process_list)
    # cells without a measurement stay 0; LogNorm treats non-positive values as invalid
    mat = np.zeros((dim, dim))
    cells = []
    for key, item in result.items():
        i_space = process_list.index(key[0])
        i_time = process_list.index(key[1])
        mat[i_space, i_time] = item
        cells.append((i_space, i_time, item))

    # derive the color range from the data itself; a fixed initial guess would cap vmin
    tmin = min(result.values())
    tmax = max(result.values())

    plt_helper.setup_mpl()
    plt_helper.newfig(textwidth=120, scale=1.5)
    cmap = plt_helper.plt.get_cmap('RdYlGn_r')
    new_cmap = truncate_colormap(cmap, 0.1, 0.9)
    # transpose so that space varies along the x-axis and time along the y-axis
    plt_helper.plt.imshow(
        mat.T, origin='lower', norm=colors.LogNorm(vmin=tmin, vmax=tmax), cmap=new_cmap, aspect='auto'
    )

    for i_space, i_time, item in cells:
        plt_helper.plt.annotate(
            "{:3.1f}".format(item),
            xy=(i_space, i_time),
            size='x-small',
            ha='center',
            va='center',
        )

    plt_helper.plt.xticks(range(dim), process_list)
    plt_helper.plt.yticks(range(dim), process_list)
    plt_helper.plt.xlabel('Cores in space')
    plt_helper.plt.ylabel('Cores in time')

    fname = 'data/runtimes_matrix_heat'
    plt_helper.savefig(fname)

    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'

127 

128 

def visualize_speedup(result=None):
    """
    Visualizes runtimes of two different runs (MLSDC vs. PFASST)

    Each entry of ``result`` is assigned to a curve by its total core count
    (cores in space times cores in time). Both curves are drawn in a log-log plot together with an
    ideal-scaling line derived from the MLSDC runtime on the smallest core count. The figure is
    saved as ``data/speedup_heat``, and the existence of the PDF and PNG files is asserted.

    Args:
        result: dictionary mapping (cores in space, cores in time) to a runtime
    """
    cores_MLSDC = [1, 2, 4, 6, 12, 24]
    cores_PFASST = [24, 48, 96, 144, 288, 576]

    runtimes_MLSDC = np.zeros(len(cores_MLSDC))
    runtimes_PFASST = np.zeros(len(cores_PFASST))
    curves = ((cores_MLSDC, runtimes_MLSDC), (cores_PFASST, runtimes_PFASST))

    for key, runtime in result.items():
        total_cores = key[0] * key[1]
        # a core count may belong to both curves (24 cores appears in both lists)
        for cores, runtimes in curves:
            if total_cores in cores:
                runtimes[cores.index(total_cores)] = runtime

    plt_helper.setup_mpl()
    plt_helper.newfig(textwidth=120, scale=1.5)
    plot = plt_helper.plt

    all_cores = cores_MLSDC + cores_PFASST
    reference = runtimes_MLSDC[0]
    ideal = [reference / ncores for ncores in all_cores]
    plot.loglog(all_cores, ideal, 'k--', label='ideal')
    plot.loglog(cores_MLSDC, runtimes_MLSDC, 'bo-', label='MLSDC')
    plot.loglog(cores_PFASST, runtimes_PFASST, 'rs-', label='PFASST')

    plot.xlim(all_cores[0] / 2, all_cores[-1] * 2)
    plot.ylim(ideal[-1] / 2, ideal[0] * 2)
    plot.xlabel('Number of cores')
    plot.ylabel('Runtime (sec.)')

    plot.legend()
    plot.grid()

    fname = 'data/speedup_heat'
    plt_helper.savefig(fname)
    assert os.path.isfile(fname + '.pdf'), 'ERROR: plotting did not create PDF file'
    assert os.path.isfile(fname + '.png'), 'ERROR: plotting did not create PNG file'

169 

170 

def main(cwd=''):
    """
    Main routine to call them all

    Reads the PFASST result tables and draws the runtime matrix, then reads the MLSDC and
    multi-node PFASST tables and draws the speedup plot. Each plot gets its own fresh dictionary.

    Args:
        cwd (str): current working directory, prepended to every data file path
    """
    matrix_files = [
        'data/result_PFASST_1_NEW.dat',
        'data/result_PFASST_2_NEW.dat',
        'data/result_PFASST_4_NEW.dat',
        'data/result_PFASST_6_NEW.dat',
        'data/result_PFASST_12_NEW.dat',
        'data/result_PFASST_24_NEW.dat',
    ]
    speedup_files = ['data/result_MLSDC_NEW.dat', 'data/result_PFASST_multinode_24_NEW.dat']

    for files, plot in ((matrix_files, visualize_matrix), (speedup_files, visualize_speedup)):
        timings = {}
        for file in files:
            timings = join_timings(file=file, result=timings, cwd=cwd)
        plot(result=timings)

198 

199 

# Script entry point: rebuild both figures, resolving the data/ files relative to the
# current working directory (main() is called with its default cwd='')
if __name__ == "__main__":
    main()