Coverage for pySDC / projects / GPU / configs / base_config.py: 73%

167 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-12 11:13 +0000

1from pySDC.core.convergence_controller import ConvergenceController 

2import pickle 

3import numpy as np 

4 

5 

def get_config(args):
    """
    Look up the configuration selected by ``args['config']``.

    The prefix of the configuration name decides which module is responsible. The module is imported only
    when it is needed and the lookup is delegated to its own ``get_config`` function.

    Args:
        args (dict): Arguments; the entry ``'config'`` holds the name of the configuration

    Returns:
        Whatever the ``get_config`` function of the responsible module returns for ``args``

    Raises:
        NotImplementedError: If the name starts with none of the known prefixes
    """
    config_name = args['config']

    # 'RBC3D' has to be tested before 'RBC' because the latter is a prefix of the former
    if config_name.startswith('GS'):
        from pySDC.projects.GPU.configs.GS_configs import get_config as delegate
    elif config_name.startswith('RBC3D'):
        from pySDC.projects.RayleighBenard.RBC3D_configs import get_config as delegate
    elif config_name.startswith('RBC'):
        from pySDC.projects.GPU.configs.RBC_configs import get_config as delegate
    else:
        raise NotImplementedError(f'There is no configuration called {config_name!r}!')

    return delegate(args)

18 

19 

def get_comms(n_procs_list, comm_world=None, _comm=None, _tot_rank=0, _rank=None, useGPU=False):
    """
    Recursively split ``comm_world`` into one communicator per entry of ``n_procs_list``.

    For the first entry of ``n_procs_list``, ``comm_world`` is split with the color
    ``_tot_rank + _rank // n_procs_list[0]``. The function then calls itself for the remaining entries with
    updated values for the underscore-prefixed arguments, which carry state between the recursion levels.

    Args:
        n_procs_list (list): Expected size of each communicator
        comm_world: Communicator that is split, defaults to ``MPI.COMM_WORLD``
        _comm: Communicator created on the previous recursion level, defaults to ``comm_world``
        _tot_rank (int): Offset added to the color on this recursion level
        _rank (int): Rank used to compute the color on this recursion level, defaults to ``comm_world.rank``
        useGPU (bool): Try to wrap every new communicator in an ``NCCLComm``

    Returns:
        list: One communicator per entry of ``n_procs_list``, in the same order

    Raises:
        AssertionError: If a new communicator does not have the requested size
    """
    from mpi4py import MPI

    if comm_world is None:
        comm_world = MPI.COMM_WORLD
    if _comm is None:
        _comm = comm_world
    if _rank is None:
        _rank = comm_world.rank

    # end of the recursion: there is nothing left to split
    if len(n_procs_list) == 0:
        return []

    n_procs = n_procs_list[0]
    new_comm = comm_world.Split(_tot_rank + _rank // n_procs)

    assert new_comm.size == n_procs

    if useGPU:
        # imported outside of the try block because the except clause refers to its exception types
        import cupy_backends

        try:
            import cupy
            from pySDC.helpers.NCCL_communicator import NCCLComm

            new_comm = NCCLComm(new_comm)
        except (
            ImportError,
            cupy_backends.cuda.api.runtime.CUDARuntimeError,
            cupy_backends.cuda.libs.nccl.NcclError,
        ):
            # keep the communicator obtained from the split if NCCL cannot be used
            print('Warning: Communicator is MPI instead of NCCL in spite of using GPUs!')

    remaining_comms = get_comms(
        n_procs_list[1:],
        comm_world,
        _comm=new_comm,
        _tot_rank=_tot_rank + _comm.size * new_comm.rank,
        _rank=_comm.rank // new_comm.size,
        useGPU=useGPU,
    )
    return [new_comm] + remaining_comms

58 

59 

class Config(object):
    """
    Base class for run configurations.

    An instance stores the arguments and the communicators and provides defaults for the description of the
    run, the controller parameters, file names and the initial condition. The description returned by
    `get_description` has ``problem_class`` set to None and `plot` raises `NotImplementedError`.
    """

    # name of the sweeper; `get_sweeper` understands 'IMEX' and 'generic_implicit'
    sweeper_type = None
    # not used in this class; presumably the final time of the run - TODO confirm
    Tend = None
    # prefix of the path returned by `get_file_name`
    base_path = './'
    # handed to the `LogToFile` hook as `time_increment`
    logging_time_increment = 0.5

    def __init__(self, args, comm_world=None):
        """
        Store the arguments and set up the communicators.

        Args:
            args (dict): Arguments; the entries 'procs' and 'mode' are read here and 'useGPU' is read if the
                mode is 'run'
            comm_world: Communicator that is split in 'run' mode, defaults to ``MPI.COMM_WORLD``
        """
        from mpi4py import MPI

        self.args = args
        self.comm_world = MPI.COMM_WORLD if comm_world is None else comm_world
        self.n_procs_list = args["procs"]
        if args['mode'] == 'run':
            # the list is reversed for the split and the result is reversed back, such that `self.comms` is
            # in the same order as `args['procs']`
            self.comms = get_comms(
                n_procs_list=self.n_procs_list[::-1], useGPU=args['useGPU'], comm_world=self.comm_world
            )[::-1]
        else:
            # no splitting in any other mode
            self.comms = [MPI.COMM_SELF, MPI.COMM_SELF, MPI.COMM_SELF]
        self.ranks = [me.rank for me in self.comms]

    def get_file_name(self):
        """
        Get the path of the file the solution is stored in.

        Returns:
            str: Path made up of `base_path`, the name of the class and the resolution ``args['res']``
        """
        res = self.args['res']
        return f'{self.base_path}/data/{type(self).__name__}-res{res}.pySDC'

    def get_LogToFile(self, *args, **kwargs):
        """
        Configure the `LogToFile` hook via its class attributes.

        Returns:
            The `LogToFile` class, or None if the rank in ``self.comms[1]`` is larger than zero
        """
        if self.comms[1].rank > 0:
            return None
        from pySDC.implementations.hooks.log_solution import LogToFile

        # the attributes are set on the class itself, not on an instance
        LogToFile.filename = self.get_file_name()
        LogToFile.time_increment = self.logging_time_increment
        LogToFile.allow_overwriting = True

        return LogToFile

    def get_description(self, *args, MPIsweeper=False, useGPU=False, **kwargs):
        """
        Get a default description with the problem class left as None.

        Args:
            MPIsweeper (bool): Select the MPI version of the sweeper and hand ``self.comms[1]`` to it
            useGPU (bool): Passed on in the problem parameters

        Returns:
            dict: The description
        """
        description = {}
        description['problem_class'] = None
        description['problem_params'] = {'useGPU': useGPU, 'comm': self.comms[2]}
        description['sweeper_class'] = self.get_sweeper(useMPI=MPIsweeper)
        description['sweeper_params'] = {'initial_guess': 'copy'}
        description['level_params'] = {}
        description['step_params'] = {}
        description['convergence_controllers'] = {}

        if self.get_LogToFile():
            # drop the six characters of the '.pySDC' extension from the file name
            path = self.get_file_name()[:-6]
            description['convergence_controllers'][LogStats] = {'path': path}

        if MPIsweeper:
            description['sweeper_params']['comm'] = self.comms[1]
        return description

    def get_controller_params(self, *args, logger_level=15, **kwargs):
        """
        Get default parameters for the controller.

        Args:
            logger_level (int): Logger level on rank 0 of ``self.comm_world``; all other ranks get level 40

        Returns:
            dict: The controller parameters
        """
        from pySDC.implementations.hooks.log_work import LogWork
        from pySDC.implementations.hooks.log_step_size import LogStepSize
        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller_params = {}
        controller_params['logger_level'] = logger_level if self.comm_world.rank == 0 else 40
        controller_params['hook_class'] = [LogWork, LogStepSize, LogRestarts]
        logToFile = self.get_LogToFile()
        if logToFile:
            controller_params['hook_class'] += [logToFile]
        controller_params['mssdc_jac'] = False
        return controller_params

    def get_sweeper(self, useMPI):
        """
        Get the sweeper class that belongs to `sweeper_type`.

        Args:
            useMPI (bool): Return the MPI version of the sweeper

        Returns:
            The sweeper class

        Raises:
            NotImplementedError: If `sweeper_type` is neither 'IMEX' nor 'generic_implicit'
        """
        if useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order_MPI import imex_1st_order_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order as sweeper
        elif useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper
        else:
            raise NotImplementedError(f'Don\'t know the sweeper for {self.sweeper_type=}')

        return sweeper

    def prepare_caches(self, prob):
        """
        Do nothing in this class.

        Args:
            prob: Not used here
        """
        pass

    def get_path(self, *args, ranks=None, **kwargs):
        """
        Get a name made up of the class name, the arguments and two of the ranks.

        Args:
            ranks (list): Ranks to use instead of ``self.ranks``; the entries 0 and 2 are used

        Returns:
            str: The name
        """
        ranks = self.ranks if ranks is None else ranks
        return f'{type(self).__name__}{self.args_to_str()}-{ranks[0]}-{ranks[2]}'

    def args_to_str(self, args=None):
        """
        Turn the resolution, the GPU flag and the three process counts into a string.

        Args:
            args (dict): Arguments to use instead of ``self.args``

        Returns:
            str: String of the form ``-res_<res>-useGPU_<useGPU>-procs_<p0>_<p1>_<p2>``
        """
        args = self.args if args is None else args
        name = ''

        name = f'{name}-res_{args["res"]}'
        name = f'{name}-useGPU_{args["useGPU"]}'
        name = f'{name}-procs_{args["procs"][0]}_{args["procs"][1]}_{args["procs"][2]}'
        return name

    def plot(self, P, idx, num_procs_list):
        """
        Not available in this class.

        Raises:
            NotImplementedError: Always
        """
        raise NotImplementedError

    def get_initial_condition(self, P, *args, restart_idx=0, **kwargs):
        """
        Get the initial condition, either from the problem or from the file with the stored solution.

        Args:
            P: Problem; ``u_exact`` is called if `restart_idx` is zero, otherwise ``setUpFieldsIO``,
                ``spectral.ncomponents``, ``u_init``, ``spectral_space`` and possibly ``transform`` are used
            restart_idx (int): Index of the field in the file to restart from; zero means no restart

        Returns:
            tuple: The initial condition and the time it belongs to
        """
        if restart_idx == 0:
            return P.u_exact(t=0), 0

        from pySDC.helpers.fieldsIO import FieldsIO

        P.setUpFieldsIO()
        outfile = FieldsIO.fromFile(self.get_file_name())

        t0, solution = outfile.readField(restart_idx)
        # keep only the first `ncomponents` entries along the first axis
        solution = solution[: P.spectral.ncomponents, ...]

        u0 = P.u_init

        if P.spectral_space:
            u0[...] = P.transform(solution)
        else:
            u0[...] = solution

        return u0, t0

194 

195 

class LogStats(ConvergenceController):
    """
    Convergence controller that dumps the stats of the controller to pickle files.

    Whenever the counter of the hook in the parameters has advanced, the current stats are written to a file
    and the stats of all hooks of the controller are reset. After the run, the stored files are merged and
    ``controller.return_stats`` is replaced by a function returning the merged stats.
    """

    def get_stats_path(self, index=0):
        """
        Get the path of the file with the stats belonging to `index`.

        Args:
            index (int): Index of the file, zero-padded to six digits

        Returns:
            str: The path
        """
        return f'{self.params.path}_{index:06d}-stats.pickle'

    def merge_all_stats(self, controller):
        """
        Merge the stats in the files with the stats currently held by the controller.

        Files that are missing or empty are skipped with a warning.

        Args:
            controller: Controller whose ``return_stats`` supplies the most recent stats

        Returns:
            dict: The merged stats; entries loaded later replace earlier ones with the same key
        """
        merged = {}
        for idx in range(self.params.hook.counter - 1):
            stats_path = self.get_stats_path(index=idx)
            try:
                with open(stats_path, 'rb') as stats_file:
                    loaded = pickle.load(stats_file)
                    merged = {**merged, **loaded}
            except (FileNotFoundError, EOFError):
                print(f'Warning: No stats found at path {stats_path}')

        return {**merged, **controller.return_stats()}

    def reset_stats(self, controller):
        """
        Reset the stats of all hooks of the controller.

        Args:
            controller: Controller with the hooks
        """
        for current_hook in controller.hooks:
            current_hook.reset_stats()
        self.logger.debug('Reset stats')

    def setup(self, controller, params, *args, **kwargs):
        """
        Set the control order, fall back to `LogToFile` as hook and remember the current counter of the hook.

        Args:
            controller: Controller, passed on to the parent class
            params (dict): Parameters of this convergence controller

        Returns:
            Whatever ``setup`` of the parent class returns
        """
        # presumably a high value such that this controller is called after the others - TODO confirm
        params['control_order'] = 999
        if 'hook' not in params:
            from pySDC.implementations.hooks.log_solution import LogToFile

            params['hook'] = LogToFile

        self.counter = params['hook'].counter
        return super().setup(controller, params, *args, **kwargs)

    def post_step_processing(self, controller, S, **kwargs):
        """
        Store and reset the stats if the counter of the hook has advanced since the last time.

        The file is written only if neither the communicator of the sweeper, if it has one, nor the
        communicator of the problem reports a rank larger than zero.

        Args:
            controller: Controller with the stats
            S: Step; its first level supplies the sweeper and the problem
        """
        hook = self.params.hook
        P = S.levels[0].prob

        while self.counter < hook.counter:
            path = self.get_stats_path(index=hook.counter - 2)
            stats = controller.return_stats()

            # the communicator of the problem is looked at only if the sweeper does not rule out storing
            sweep = S.levels[0].sweep
            skip = (hasattr(sweep, 'comm') and sweep.comm.rank > 0) or P.comm.rank > 0
            if not skip:
                with open(path, 'wb') as stats_file:
                    pickle.dump(stats, stats_file)
                    self.log(f'Stored stats in {path!r}', S)

            self.reset_stats(controller)
            # catching up with the hook ends the loop
            self.counter = hook.counter

    def post_run_processing(self, controller, S, **kwargs):
        """
        Store the remaining stats and make the controller return the merged stats from now on.

        Args:
            controller: Controller whose ``return_stats`` is replaced
            S: Step, passed on to `post_step_processing`
        """
        self.post_step_processing(controller, S, **kwargs)

        stats = self.merge_all_stats(controller)
        controller.return_stats = lambda: stats