Coverage for pySDC / projects / GPU / configs / base_config.py: 71%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 07:06 +0000

1from pySDC.core.convergence_controller import ConvergenceController 

2import pickle 

3import numpy as np 

4 

5 

def get_config(args):
    """
    Select the configuration factory matching ``args['config']`` and call it.

    The prefix of the configuration name decides which module supplies the factory. The module
    is imported only once its prefix has matched. ``'RBC3D'`` is tested before ``'RBC'`` because
    every name starting with ``'RBC3D'`` also starts with ``'RBC'``.

    Args:
        args (dict): Arguments of the run. Only the ``'config'`` entry is read here; the whole
            dictionary is forwarded to the selected factory.

    Returns:
        Whatever the selected ``get_config`` factory returns for ``args``.

    Raises:
        NotImplementedError: If the name starts with none of ``'GS'``, ``'RBC3D'`` or ``'RBC'``.
    """
    config_name = args['config']

    if config_name.startswith('GS'):
        from pySDC.projects.GPU.configs.GS_configs import get_config as factory

        return factory(args)

    if config_name.startswith('RBC3D'):
        from pySDC.projects.RayleighBenard.RBC3D_configs import get_config as factory

        return factory(args)

    if config_name.startswith('RBC'):
        from pySDC.projects.GPU.configs.RBC_configs import get_config as factory

        return factory(args)

    raise NotImplementedError(f'There is no configuration called {config_name!r}!')

18 

19 

def get_comms(n_procs_list, comm_world=None, _comm=None, _tot_rank=0, _rank=None, useGPU=False):
    """
    Recursively build one sub-communicator per entry of ``n_procs_list``.

    Each recursion level splits ``comm_world`` with a color computed from the first entry of
    ``n_procs_list`` and then recurses on the remaining entries. The arguments with a leading
    underscore carry state between recursion levels.

    Args:
        n_procs_list (list): Requested size of each communicator, one entry per level.
        comm_world: Communicator that is split; defaults to ``MPI.COMM_WORLD``.
        _comm: Communicator created by the previous recursion level; defaults to ``comm_world``.
        _tot_rank (int): Offset added to the color, accumulated over the recursion levels.
        _rank (int): Rank used in the color computation; defaults to ``comm_world.rank``.
        useGPU (bool): If set, try to wrap each new communicator in ``NCCLComm``.

    Returns:
        list: The created communicators, in the order of ``n_procs_list``. Empty if
        ``n_procs_list`` is empty.
    """
    from mpi4py import MPI

    if comm_world is None:
        comm_world = MPI.COMM_WORLD
    if _comm is None:
        _comm = comm_world
    if _rank is None:
        _rank = comm_world.rank

    # The recursion ends once every requested size has been consumed.
    if len(n_procs_list) == 0:
        return []

    group_size = n_procs_list[0]
    sub_comm = comm_world.Split(_tot_rank + _rank // group_size)

    assert sub_comm.size == group_size

    if useGPU:
        # Imported outside the try block because the except clause below needs its error types.
        import cupy_backends

        try:
            import cupy
            from pySDC.helpers.NCCL_communicator import NCCLComm

            sub_comm = NCCLComm(sub_comm)
        except (
            ImportError,
            cupy_backends.cuda.api.runtime.CUDARuntimeError,
            cupy_backends.cuda.libs.nccl.NcclError,
        ):
            # Keep the plain MPI communicator when the NCCL wrapper cannot be set up.
            print('Warning: Communicator is MPI instead of NCCL in spite of using GPUs!')

    deeper_comms = get_comms(
        n_procs_list[1:],
        comm_world,
        _comm=sub_comm,
        _tot_rank=_tot_rank + _comm.size * sub_comm.rank,
        _rank=_comm.rank // sub_comm.size,
        useGPU=useGPU,
    )
    return [sub_comm] + deeper_comms

58 

59 

class Config(object):
    """
    Base class for run configurations.

    Stores the ``args`` dictionary of a run together with the communicators used for it and
    assembles the pieces of a pySDC setup from them: description, controller parameters, sweeper
    class, initial condition and file names.

    The class attributes only hold generic defaults here; presumably derived configurations
    override them -- verify against the subclasses.
    """

    # Name of the sweeper; `get_sweeper` understands 'IMEX' and 'generic_implicit'.
    sweeper_type = None
    # Not read anywhere in this class -- presumably the final time of the run, TODO confirm.
    Tend = None
    # Leading part of the path returned by `get_file_name`.
    base_path = './'
    # Assigned to `LogToFile.time_increment` in `get_LogToFile`.
    logging_time_increment = 0.5

    def __init__(self, args, comm_world=None):
        """
        Store the arguments and set up the communicators.

        For the modes ``'run'`` and ``'benchmark'`` the communicators come from ``get_comms``.
        With a ``'space_first'`` / ``'space_major'`` distribution both the process list handed
        to ``get_comms`` and the list it returns are reversed; with ``'time_first'`` /
        ``'time_major'`` (the default) the process list is used as is. For every other mode,
        three ``MPI.COMM_SELF`` communicators are used.

        Args:
            args (dict): Arguments of the run. Reads ``'procs'`` and ``'mode'``, and for the
                modes above also ``'useGPU'`` and the optional ``'distribution'``.
            comm_world: Communicator to use instead of ``MPI.COMM_WORLD``.

        Raises:
            NotImplementedError: If ``args['distribution']`` is none of the names above.
        """
        from mpi4py import MPI

        self.args = args
        self.comm_world = MPI.COMM_WORLD if comm_world is None else comm_world
        self.n_procs_list = args["procs"]
        if args['mode'] in ['run', 'benchmark']:
            distribution = args.get('distribution', 'time_first')
            if distribution in ['space_first', 'space_major']:
                self.comms = get_comms(
                    n_procs_list=self.n_procs_list[::-1], useGPU=args['useGPU'], comm_world=self.comm_world
                )[::-1]
            elif distribution in ['time_first', 'time_major']:
                self.comms = get_comms(
                    n_procs_list=self.n_procs_list, useGPU=args['useGPU'], comm_world=self.comm_world
                )
            else:
                raise NotImplementedError(f'There is no process distribution called {distribution!r}!')
        else:
            self.comms = [MPI.COMM_SELF, MPI.COMM_SELF, MPI.COMM_SELF]
        self.ranks = [me.rank for me in self.comms]

    def get_file_name(self):
        """
        Return the path of the solution file for this configuration.

        Returns:
            str: ``base_path``, ``/data/``, the class name and the resolution ``args['res']``,
            with the extension ``.pySDC``.
        """
        res = self.args['res']
        return f'{self.base_path}/data/{type(self).__name__}-res{res}.pySDC'

    def get_LogToFile(self, *args, **kwargs):
        """
        Configure and return the ``LogToFile`` hook class.

        The file name, time increment and overwrite flag are set on the class itself, not on an
        instance, so they are shared by everything using ``LogToFile``.

        Returns:
            The ``LogToFile`` class, or ``None`` on ranks greater than zero of ``self.comms[1]``.
        """
        if self.comms[1].rank > 0:
            return None
        from pySDC.implementations.hooks.log_solution import LogToFile

        LogToFile.filename = self.get_file_name()
        LogToFile.time_increment = self.logging_time_increment
        LogToFile.allow_overwriting = True

        return LogToFile

    def get_description(self, *args, MPIsweeper=False, useGPU=False, **kwargs):
        """
        Assemble the description dictionary of the run.

        ``'problem_class'`` is left as ``None`` here; presumably derived configurations fill it
        in -- verify against the subclasses.

        Args:
            MPIsweeper (bool): Select the MPI version of the sweeper and pass ``self.comms[1]``
                to it as ``'comm'``.
            useGPU (bool): Forwarded to the problem parameters.

        Returns:
            dict: The description.
        """
        description = {}
        description['problem_class'] = None
        description['problem_params'] = {'useGPU': useGPU, 'comm': self.comms[2]}
        description['sweeper_class'] = self.get_sweeper(useMPI=MPIsweeper)
        description['sweeper_params'] = {'initial_guess': 'copy'}
        description['level_params'] = {}
        description['step_params'] = {}
        description['convergence_controllers'] = {}

        # Only ranks that get the LogToFile hook also store statistics.
        if self.get_LogToFile():
            # Drop the six characters of the '.pySDC' extension produced by `get_file_name`.
            path = self.get_file_name()[:-6]
            description['convergence_controllers'][LogStats] = {'path': path}

        if MPIsweeper:
            description['sweeper_params']['comm'] = self.comms[1]
        return description

    def get_controller_params(self, *args, logger_level=15, **kwargs):
        """
        Assemble the controller parameters.

        Args:
            logger_level (int): Logger level used on rank 0 of ``self.comm_world``; all other
                ranks use level 40.

        Returns:
            dict: Controller parameters with the hooks ``LogWork``, ``LogStepSize`` and
            ``LogRestarts``, plus the hook from ``get_LogToFile`` when that is not ``None``.
        """
        from pySDC.implementations.hooks.log_work import LogWork
        from pySDC.implementations.hooks.log_step_size import LogStepSize
        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller_params = {}
        controller_params['logger_level'] = logger_level if self.comm_world.rank == 0 else 40
        controller_params['hook_class'] = [LogWork, LogStepSize, LogRestarts]
        logToFile = self.get_LogToFile()
        if logToFile:
            controller_params['hook_class'] += [logToFile]
        controller_params['mssdc_jac'] = False
        return controller_params

    def get_sweeper(self, useMPI):
        """
        Return the sweeper class matching ``self.sweeper_type``.

        Args:
            useMPI (bool): Return the MPI version of the sweeper.

        Returns:
            The sweeper class.

        Raises:
            NotImplementedError: If ``self.sweeper_type`` is neither ``'IMEX'`` nor
                ``'generic_implicit'``.
        """
        if useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order_MPI import imex_1st_order_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order as sweeper
        elif useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper
        else:
            raise NotImplementedError(f'Don\'t know the sweeper for {self.sweeper_type=}')

        return sweeper

    def prepare_caches(self, prob):
        """Do nothing in this base class."""
        pass

    def get_path(self, *args, ranks=None, **kwargs):
        """
        Return an identifier built from the class name, the arguments and two ranks.

        Args:
            ranks (list): Ranks to use instead of ``self.ranks``. Only the entries at index 0
                and index 2 are used.

        Returns:
            str: The identifier.
        """
        ranks = self.ranks if ranks is None else ranks
        return f'{type(self).__name__}{self.args_to_str()}-{ranks[0]}-{ranks[2]}'

    def args_to_str(self, args=None):
        """
        Encode resolution, GPU usage and the process counts in a string.

        Args:
            args (dict): Arguments to encode instead of ``self.args``. Reads ``'res'``,
                ``'useGPU'`` and the first three entries of ``'procs'``.

        Returns:
            str: For example ``'-res_32-useGPU_False-procs_1_4_2'``.
        """
        args = self.args if args is None else args
        procs = args["procs"]
        return f'-res_{args["res"]}-useGPU_{args["useGPU"]}-procs_{procs[0]}_{procs[1]}_{procs[2]}'

    def plot(self, P, idx, num_procs_list):
        """Not implemented in this base class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def get_initial_condition(self, P, *args, restart_idx=0, **kwargs):
        """
        Return the initial condition and the time it belongs to.

        Args:
            P: Problem instance.
            restart_idx (int): With 0, start from ``P.u_exact(t=0)``. Otherwise read the field
                with this index from the file given by ``get_file_name``.

        Returns:
            tuple: The initial condition and the initial time.
        """
        if restart_idx == 0:
            return P.u_exact(t=0), 0

        from pySDC.helpers.fieldsIO import FieldsIO

        # Called before the file is opened -- presumably prepares FieldsIO for this problem,
        # TODO confirm against the problem class.
        P.setUpFieldsIO()
        outfile = FieldsIO.fromFile(self.get_file_name())

        t0, solution = outfile.readField(restart_idx)
        # Keep only the leading `P.spectral.ncomponents` entries along the first axis.
        solution = solution[: P.spectral.ncomponents, ...]

        u0 = P.u_init

        if P.spectral_space:
            u0[...] = P.transform(solution)
        else:
            u0[...] = solution

        return u0, t0

202 

203 

class LogStats(ConvergenceController):
    """
    Convergence controller that stores the statistics of the controller in pickle files.

    Whenever the counter of the hook in ``params.hook`` has advanced, the current statistics are
    written to a file and the statistics of all hooks are reset. After the run, the files are
    merged and ``controller.return_stats`` is replaced so that it returns the merged result.
    """

    def get_stats_path(self, index=0):
        """
        Return the name of the pickle file with the given index.

        Args:
            index (int): Index of the file, zero padded to six digits.

        Returns:
            str: ``params.path`` followed by the index and ``-stats.pickle``.
        """
        return f'{self.params.path}_{index:06d}-stats.pickle'

    def merge_all_stats(self, controller):
        """
        Merge the stored statistics with the ones currently held by the controller.

        Files that are missing or empty are skipped with a warning. Later entries override
        earlier ones with the same key, and the controller's current statistics come last.

        Args:
            controller: Controller whose ``return_stats`` supplies the current statistics.

        Returns:
            dict: The merged statistics.
        """
        merged = {}
        n_files = self.params.hook.counter - 1

        for idx in range(n_files):
            path = self.get_stats_path(index=idx)
            try:
                with open(path, 'rb') as handle:
                    chunk = pickle.load(handle)
                merged = {**merged, **chunk}
            except (FileNotFoundError, EOFError):
                print(f'Warning: No stats found at path {path}')

        return {**merged, **controller.return_stats()}

    def reset_stats(self, controller):
        """
        Reset the statistics of every hook of the controller.

        Args:
            controller: Controller whose hooks are reset.
        """
        for each_hook in controller.hooks:
            each_hook.reset_stats()
        self.logger.debug('Reset stats')

    def setup(self, controller, params, *args, **kwargs):
        """
        Set defaults and remember the current counter of the hook.

        Args:
            controller: Controller this convergence controller belongs to.
            params (dict): Parameters; ``'hook'`` defaults to ``LogToFile``.

        Returns:
            Whatever ``setup`` of the parent class returns.
        """
        # NOTE(review): presumably a large order makes this controller act after the others --
        # confirm against the ordering rules of ConvergenceController.
        params['control_order'] = 999

        if 'hook' not in params.keys():
            from pySDC.implementations.hooks.log_solution import LogToFile

            params['hook'] = LogToFile

        self.counter = params['hook'].counter
        return super().setup(controller, params, *args, **kwargs)

    def post_step_processing(self, controller, S, **kwargs):
        """
        Store and reset the statistics if the counter of the hook has advanced.

        The file is written only where the rank in the sweeper communicator (if the sweeper
        has one) and the rank in the problem communicator are both zero.

        Args:
            controller: Controller supplying the statistics.
            S: Step; its first level supplies the sweeper and the problem.
        """
        hook = self.params.hook
        level = S.levels[0]
        problem = level.prob

        while self.counter < hook.counter:
            # NOTE(review): the file index is the hook counter minus two, which matches the
            # range read back in `merge_all_stats` -- confirm against how LogToFile counts.
            path = self.get_stats_path(index=hook.counter - 2)
            current_stats = controller.return_stats()

            in_sweeper_not_root = hasattr(S.levels[0].sweep, 'comm') and S.levels[0].sweep.comm.rank > 0
            store = not (in_sweeper_not_root or problem.comm.rank > 0)

            if store:
                with open(path, 'wb') as handle:
                    pickle.dump(current_stats, handle)
                    self.log(f'Stored stats in {path!r}', S)

            self.reset_stats(controller)
            self.counter = hook.counter

    def post_run_processing(self, controller, S, **kwargs):
        """
        Store the remaining statistics and make the controller return the merged ones.

        Args:
            controller: Controller whose ``return_stats`` is replaced.
            S: Step, forwarded to ``post_step_processing``.
        """
        self.post_step_processing(controller, S, **kwargs)

        stats = self.merge_all_stats(controller)

        def return_stats():
            return stats

        # From here on the controller hands out the merged statistics of the whole run.
        controller.return_stats = return_stats