Coverage for pySDC/projects/GPU/configs/base_config.py: 71%
171 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 07:06 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-27 07:06 +0000
1from pySDC.core.convergence_controller import ConvergenceController
2import pickle
3import numpy as np
def get_config(args):
    """
    Dispatch to the configuration factory that matches ``args['config']``.

    The configuration name is matched by prefix. ``'RBC3D'`` is tested before
    ``'RBC'`` because every ``'RBC3D...'`` name also starts with ``'RBC'``.
    The module holding the factory is imported only once its prefix matched.

    Args:
        args (dict): Must contain the key ``'config'``; forwarded unchanged
            to the selected factory.

    Returns:
        Whatever the selected module's ``get_config`` returns for ``args``.

    Raises:
        NotImplementedError: If the name starts with none of the known prefixes.
    """
    name = args['config']

    if name.startswith('GS'):
        from pySDC.projects.GPU.configs.GS_configs import get_config as factory
    elif name.startswith('RBC3D'):
        from pySDC.projects.RayleighBenard.RBC3D_configs import get_config as factory
    elif name.startswith('RBC'):
        from pySDC.projects.GPU.configs.RBC_configs import get_config as factory
    else:
        raise NotImplementedError(f'There is no configuration called {name!r}!')

    return factory(args)
def get_comms(n_procs_list, comm_world=None, _comm=None, _tot_rank=0, _rank=None, useGPU=False):
    """
    Recursively split ``comm_world`` into one sub-communicator per entry of ``n_procs_list``.

    For the first entry, ``comm_world.Split`` is called with the color
    ``_tot_rank + _rank // n_procs_list[0]``; the function then recurses on the
    remaining entries with updated bookkeeping values.

    Args:
        n_procs_list (sequence of int): Requested size of each sub-communicator.
        comm_world: Communicator to split; defaults to ``MPI.COMM_WORLD``.
        _comm: Communicator created on the previous recursion level (internal);
            defaults to ``comm_world``.
        _tot_rank (int): Color offset accumulated over previous levels (internal).
        _rank (int): Rank used to compute the color on this level (internal);
            defaults to ``comm_world.rank``.
        useGPU (bool): If set, try to wrap each new communicator in ``NCCLComm``.

    Returns:
        list: One communicator per entry of ``n_procs_list``, in the same order.
        Empty if ``n_procs_list`` is empty.
    """
    from mpi4py import MPI

    if comm_world is None:
        comm_world = MPI.COMM_WORLD
    if _comm is None:
        _comm = comm_world
    if _rank is None:
        _rank = comm_world.rank

    # nothing left to split off: this ends the recursion
    if len(n_procs_list) == 0:
        return []

    n_procs = n_procs_list[0]
    sub_comm = comm_world.Split(_tot_rank + _rank // n_procs)

    assert sub_comm.size == n_procs_list[0]

    if useGPU:
        import cupy_backends

        try:
            # importing cupy here makes a missing installation surface as ImportError
            import cupy
            from pySDC.helpers.NCCL_communicator import NCCLComm

            sub_comm = NCCLComm(sub_comm)
        except (
            ImportError,
            cupy_backends.cuda.api.runtime.CUDARuntimeError,
            cupy_backends.cuda.libs.nccl.NcclError,
        ):
            # keep the plain MPI communicator and only warn
            print('Warning: Communicator is MPI instead of NCCL in spite of using GPUs!')

    remaining_comms = get_comms(
        n_procs_list[1:],
        comm_world,
        _comm=sub_comm,
        _tot_rank=_tot_rank + _comm.size * sub_comm.rank,
        _rank=_comm.rank // sub_comm.size,
        useGPU=useGPU,
    )
    return [sub_comm] + remaining_comms
class Config(object):
    """
    Base class for the run configurations in this module.

    Stores the ``args`` dictionary together with the communicators created by
    ``get_comms`` and offers factories for the description dictionary, the
    controller parameters, the sweeper class and file names.

    ``description['problem_class']`` is left as ``None`` in ``get_description``
    and ``plot`` raises ``NotImplementedError``; presumably subclasses fill
    these in.

    Keys read from ``args``: ``'procs'``, ``'mode'``, ``'useGPU'``, ``'res'``
    and, optionally, ``'distribution'``.
    """

    # name of the sweeper family; ``get_sweeper`` knows 'IMEX' and 'generic_implicit'
    sweeper_type = None
    # not used within this class
    Tend = None
    # prefix of the path returned by ``get_file_name``
    base_path = './'
    # value assigned to ``LogToFile.time_increment`` in ``get_LogToFile``
    logging_time_increment = 0.5

    def __init__(self, args, comm_world=None):
        """
        Store the arguments and set up the communicators.

        In the modes ``'run'`` and ``'benchmark'`` the communicators come from
        ``get_comms``. For ``'space_first'`` / ``'space_major'`` the process list
        is reversed before splitting and the result is reversed back, so that
        ``self.comms`` is always ordered like ``args['procs']``. In any other
        mode, three copies of ``MPI.COMM_SELF`` are used.

        Args:
            args (dict): Run arguments, see the class docstring for the keys.
            comm_world: Communicator to split; defaults to ``MPI.COMM_WORLD``.

        Raises:
            NotImplementedError: If ``args['distribution']`` is not recognised.
        """
        from mpi4py import MPI

        self.args = args
        self.comm_world = MPI.COMM_WORLD if comm_world is None else comm_world
        self.n_procs_list = args["procs"]
        if args['mode'] in ['run', 'benchmark']:
            distribution = args.get('distribution', 'time_first')
            if distribution in ['space_first', 'space_major']:
                self.comms = get_comms(
                    n_procs_list=self.n_procs_list[::-1], useGPU=args['useGPU'], comm_world=self.comm_world
                )[::-1]
            elif distribution in ['time_first', 'time_major']:
                self.comms = get_comms(
                    n_procs_list=self.n_procs_list, useGPU=args['useGPU'], comm_world=self.comm_world
                )
            else:
                raise NotImplementedError(f'Don\'t know how to distribute processes for {distribution=}')
        else:
            self.comms = [MPI.COMM_SELF, MPI.COMM_SELF, MPI.COMM_SELF]
        self.ranks = [me.rank for me in self.comms]

    def get_file_name(self):
        """
        Return the path of the solution file for this configuration.

        Returns:
            str: ``{base_path}/data/{class name}-res{args['res']}.pySDC``
        """
        res = self.args['res']
        return f'{self.base_path}/data/{type(self).__name__}-res{res}.pySDC'

    def get_LogToFile(self, *args, **kwargs):
        """
        Return the ``LogToFile`` hook class, configured for this configuration.

        Note that the file name, time increment and overwrite flag are set as
        class attributes on ``LogToFile`` itself.

        Returns:
            The ``LogToFile`` class, or ``None`` on processes whose rank in
            ``self.comms[1]`` is larger than zero.
        """
        if self.comms[1].rank > 0:
            return None
        from pySDC.implementations.hooks.log_solution import LogToFile

        LogToFile.filename = self.get_file_name()
        LogToFile.time_increment = self.logging_time_increment
        LogToFile.allow_overwriting = True

        return LogToFile

    def get_description(self, *args, MPIsweeper=False, useGPU=False, **kwargs):
        """
        Build the description dictionary.

        Args:
            MPIsweeper (bool): Select the MPI variant of the sweeper and pass
                ``self.comms[1]`` to it as ``comm``.
            useGPU (bool): Forwarded to the problem parameters.

        Returns:
            dict: Description with ``'problem_class'`` set to ``None``.

        Raises:
            NotImplementedError: From ``get_sweeper`` if ``sweeper_type`` is unknown.
        """
        description = {}
        description['problem_class'] = None
        description['problem_params'] = {'useGPU': useGPU, 'comm': self.comms[2]}
        description['sweeper_class'] = self.get_sweeper(useMPI=MPIsweeper)
        description['sweeper_params'] = {'initial_guess': 'copy'}
        description['level_params'] = {}
        description['step_params'] = {}
        description['convergence_controllers'] = {}

        # only processes that log the solution also store the stats
        if self.get_LogToFile():
            # drop the last six characters, i.e. the '.pySDC' ending from ``get_file_name``
            path = self.get_file_name()[:-6]
            description['convergence_controllers'][LogStats] = {'path': path}

        if MPIsweeper:
            description['sweeper_params']['comm'] = self.comms[1]
        return description

    def get_controller_params(self, *args, logger_level=15, **kwargs):
        """
        Build the controller parameters.

        Args:
            logger_level (int): Logger level used on rank 0 of ``self.comm_world``;
                all other ranks use level 40.

        Returns:
            dict: Controller parameters including the list of hook classes.
        """
        from pySDC.implementations.hooks.log_work import LogWork
        from pySDC.implementations.hooks.log_step_size import LogStepSize
        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller_params = {}
        controller_params['logger_level'] = logger_level if self.comm_world.rank == 0 else 40
        controller_params['hook_class'] = [LogWork, LogStepSize, LogRestarts]
        logToFile = self.get_LogToFile()
        if logToFile:
            controller_params['hook_class'] += [logToFile]
        controller_params['mssdc_jac'] = False
        return controller_params

    def get_sweeper(self, useMPI):
        """
        Return the sweeper class matching ``self.sweeper_type``.

        Args:
            useMPI (bool): Whether to return the MPI variant of the sweeper.

        Returns:
            The sweeper class.

        Raises:
            NotImplementedError: If ``self.sweeper_type`` is neither ``'IMEX'``
                nor ``'generic_implicit'``.
        """
        if useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order_MPI import imex_1st_order_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order as sweeper
        elif useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper
        else:
            raise NotImplementedError(f'Don\'t know the sweeper for {self.sweeper_type=}')

        return sweeper

    def prepare_caches(self, prob):
        """Do nothing in this base class."""
        pass

    def get_path(self, *args, ranks=None, **kwargs):
        """
        Return a name built from the class name, the arguments and two ranks.

        Args:
            ranks (sequence of int): Ranks to use instead of ``self.ranks``.
                Only the entries at index 0 and 2 enter the name.

        Returns:
            str: ``{class name}{args_to_str()}-{ranks[0]}-{ranks[2]}``
        """
        ranks = self.ranks if ranks is None else ranks
        return f'{type(self).__name__}{self.args_to_str()}-{ranks[0]}-{ranks[2]}'

    def args_to_str(self, args=None):
        """
        Encode resolution, GPU flag and the first three process counts in a string.

        Args:
            args (dict): Arguments to encode; defaults to ``self.args``.

        Returns:
            str: ``-res_{res}-useGPU_{useGPU}-procs_{p0}_{p1}_{p2}``
        """
        args = self.args if args is None else args
        name = ''

        name = f'{name}-res_{args["res"]}'
        name = f'{name}-useGPU_{args["useGPU"]}'
        name = f'{name}-procs_{args["procs"][0]}_{args["procs"][1]}_{args["procs"][2]}'
        return name

    def plot(self, P, idx, num_procs_list):
        """Not implemented in this base class."""
        raise NotImplementedError

    def get_initial_condition(self, P, *args, restart_idx=0, **kwargs):
        """
        Return the initial condition and the time it belongs to.

        Without restart, this is ``P.u_exact(t=0)`` at time 0. Otherwise, the
        field with index ``restart_idx`` is read from the file named by
        ``get_file_name`` and copied into ``P.u_init``.

        Args:
            P: Problem instance.
            restart_idx (int): Index of the stored field to restart from;
                0 means no restart.

        Returns:
            tuple: Initial condition and initial time.
        """
        if restart_idx == 0:
            return P.u_exact(t=0), 0

        from pySDC.helpers.fieldsIO import FieldsIO

        P.setUpFieldsIO()
        outfile = FieldsIO.fromFile(self.get_file_name())

        t0, solution = outfile.readField(restart_idx)
        # keep only the first ``P.spectral.ncomponents`` entries along the leading axis
        solution = solution[: P.spectral.ncomponents, ...]

        u0 = P.u_init

        if P.spectral_space:
            u0[...] = P.transform(solution)
        else:
            u0[...] = solution

        return u0, t0
class LogStats(ConvergenceController):
    """
    Convergence controller that stores the controller's stats in pickle files.

    Whenever the counter of the hook in ``params.hook`` has advanced past the
    value remembered in ``self.counter``, the current stats are written to a
    file, the stats of all hooks are reset and the remembered counter is
    updated. After the run, the stored files are merged with the remaining
    stats and ``controller.return_stats`` is replaced by a function returning
    the merged result.
    """

    def get_stats_path(self, index=0):
        """
        Return the path of the stats file with the given index.

        Args:
            index (int): Index of the stats file, zero-padded to six digits.

        Returns:
            str: ``{params.path}_{index:06d}-stats.pickle``
        """
        return f'{self.params.path}_{index:06d}-stats.pickle'

    def merge_all_stats(self, controller):
        """
        Combine the stats files with indices below ``hook.counter - 1`` and the current stats.

        Missing or empty files are skipped with a printed warning. Entries
        merged later replace earlier ones with the same key.

        Args:
            controller: Controller whose ``return_stats`` supplies the stats
                that have not been written to file.

        Returns:
            dict: The merged stats.
        """
        n_files = self.params.hook.counter - 1

        merged = {}
        for idx in range(n_files):
            try:
                # NOTE: unpickling can execute arbitrary code; only load files written by this class
                with open(self.get_stats_path(index=idx), 'rb') as stats_file:
                    loaded = pickle.load(stats_file)
                    merged.update(loaded)
            except (FileNotFoundError, EOFError):
                print(f'Warning: No stats found at path {self.get_stats_path(index=idx)}')

        merged.update(controller.return_stats())
        return merged

    def reset_stats(self, controller):
        """
        Call ``reset_stats`` on every hook of the controller.

        Args:
            controller: Controller providing the ``hooks``.
        """
        for active_hook in controller.hooks:
            active_hook.reset_stats()
        self.logger.debug('Reset stats')

    def setup(self, controller, params, *args, **kwargs):
        """
        Set the control order, the default hook and the initial counter.

        Args:
            controller: Controller this convergence controller belongs to.
            params (dict): Parameters; ``'control_order'`` is overwritten with
                999 and ``'hook'`` defaults to ``LogToFile``.

        Returns:
            Result of the parent class's ``setup``.
        """
        params['control_order'] = 999

        if 'hook' not in params:
            from pySDC.implementations.hooks.log_solution import LogToFile

            params['hook'] = LogToFile

        # remember where the hook's counter stands so that later changes can be detected
        self.counter = params['hook'].counter
        return super().setup(controller, params, *args, **kwargs)

    def post_step_processing(self, controller, S, **kwargs):
        """
        Store and reset the stats if the hook's counter has advanced.

        The file is only written if the sweeper has no ``comm`` attribute or
        its rank there is not larger than zero, and the rank in the problem's
        ``comm`` is not larger than zero either. The stats are reset on every
        process.

        Args:
            controller: Controller supplying and holding the stats.
            S: Step whose first level provides the sweeper and the problem.
        """
        hook = self.params.hook

        level = S.levels[0]
        prob = level.prob

        while self.counter < hook.counter:
            path = self.get_stats_path(index=hook.counter - 2)
            current_stats = controller.return_stats()

            # short-circuits exactly like the original if/elif chain:
            # the problem communicator is only consulted when the sweeper check passed
            sweeper_is_root = not (hasattr(S.levels[0].sweep, 'comm') and S.levels[0].sweep.comm.rank > 0)
            store = sweeper_is_root and not (prob.comm.rank > 0)

            if store:
                with open(path, 'wb') as stats_file:
                    pickle.dump(current_stats, stats_file)
                    self.log(f'Stored stats in {path!r}', S)

            self.reset_stats(controller)
            self.counter = hook.counter

    def post_run_processing(self, controller, S, **kwargs):
        """
        Store outstanding stats and make the controller return the merged stats.

        Args:
            controller: Controller whose ``return_stats`` is replaced.
            S: Step, forwarded to ``post_step_processing``.
        """
        self.post_step_processing(controller, S, **kwargs)

        merged = self.merge_all_stats(controller)

        def return_stats():
            return merged

        controller.return_stats = return_stats