Coverage for pySDC / projects / GPU / configs / base_config.py: 73%
167 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-12 11:13 +0000
1from pySDC.core.convergence_controller import ConvergenceController
2import pickle
3import numpy as np
def get_config(args):
    """Dispatch to the project-specific ``get_config`` selected by the config name.

    Args:
        args (dict): Arguments; ``args['config']`` selects the configuration family
            by its name prefix.

    Returns:
        The configuration object produced by the matching project's ``get_config``.

    Raises:
        NotImplementedError: If the name matches no known configuration family.
    """
    name = args['config']

    # NOTE: the more specific 'RBC3D' prefix must be checked before 'RBC'.
    if name.startswith('GS'):
        from pySDC.projects.GPU.configs.GS_configs import get_config as _get_config
    elif name.startswith('RBC3D'):
        from pySDC.projects.RayleighBenard.RBC3D_configs import get_config as _get_config
    elif name.startswith('RBC'):
        from pySDC.projects.GPU.configs.RBC_configs import get_config as _get_config
    else:
        raise NotImplementedError(f'There is no configuration called {name!r}!')

    return _get_config(args)
def get_comms(n_procs_list, comm_world=None, _comm=None, _tot_rank=0, _rank=None, useGPU=False):
    """Recursively split ``comm_world`` into a hierarchy of communicators.

    One communicator is created per entry of ``n_procs_list`` by coloring ranks
    such that ``n_procs_list[0]`` consecutive ranks share the first communicator,
    then recursing on the remaining entries.

    Args:
        n_procs_list (list): Number of processes per level, outermost level first.
        comm_world: Base communicator to split (defaults to ``MPI.COMM_WORLD``).
        _comm: Communicator of the previous recursion level (internal).
        _tot_rank (int): Accumulated rank offset from previous levels (internal).
        _rank (int): Rank within the previous level (internal).
        useGPU (bool): If True, try to wrap each new communicator in an NCCL
            communicator; fall back to plain MPI with a warning on failure.

    Returns:
        list: One communicator per entry of ``n_procs_list``.
    """
    from mpi4py import MPI

    comm_world = MPI.COMM_WORLD if comm_world is None else comm_world
    _comm = comm_world if _comm is None else _comm
    _rank = comm_world.rank if _rank is None else _rank

    if len(n_procs_list) == 0:
        return []

    # Ranks sharing a color end up in the same sub-communicator.
    color = _tot_rank + _rank // n_procs_list[0]
    new_comm = comm_world.Split(color)

    assert new_comm.size == n_procs_list[0]

    if useGPU:
        # Bug fix: `import cupy_backends` used to live OUTSIDE this try block,
        # so a missing cupy installation raised ImportError instead of falling
        # back to MPI. All GPU-related imports are now inside the try; the
        # ImportError clause must come first because the second except tuple
        # can only be evaluated once cupy_backends imported successfully.
        try:
            import cupy_backends
            import cupy  # noqa: F401 -- probe that cupy itself is importable
            from pySDC.helpers.NCCL_communicator import NCCLComm

            new_comm = NCCLComm(new_comm)
        except ImportError:
            print('Warning: Communicator is MPI instead of NCCL in spite of using GPUs!')
        except (
            cupy_backends.cuda.api.runtime.CUDARuntimeError,
            cupy_backends.cuda.libs.nccl.NcclError,
        ):
            print('Warning: Communicator is MPI instead of NCCL in spite of using GPUs!')

    return [new_comm] + get_comms(
        n_procs_list[1:],
        comm_world,
        _comm=new_comm,
        _tot_rank=_tot_rank + _comm.size * new_comm.rank,
        _rank=_comm.rank // new_comm.size,
        useGPU=useGPU,
    )
class Config(object):
    """Base class for pySDC run configurations.

    Subclasses set the class attributes and extend the ``get_*`` methods to
    assemble the description and controller parameters for a specific problem.

    Attributes:
        sweeper_type (str): Sweeper family, 'IMEX' or 'generic_implicit'.
        Tend (float): Final simulation time.
        base_path (str): Root directory for output files.
        logging_time_increment (float): Simulation-time interval between solution logs.
    """

    sweeper_type = None
    Tend = None
    base_path = './'
    logging_time_increment = 0.5

    def __init__(self, args, comm_world=None):
        """Store the arguments and build the communicator hierarchy.

        Args:
            args (dict): Command line arguments ('procs', 'mode', 'useGPU', 'res', ...).
            comm_world: Base communicator (defaults to ``MPI.COMM_WORLD``).
        """
        from mpi4py import MPI

        self.args = args
        self.comm_world = MPI.COMM_WORLD if comm_world is None else comm_world
        self.n_procs_list = args["procs"]
        if args['mode'] == 'run':
            # get_comms splits outermost-first, so reverse in and out to keep
            # self.comms ordered as [step, sweeper, space].
            self.comms = get_comms(
                n_procs_list=self.n_procs_list[::-1], useGPU=args['useGPU'], comm_world=self.comm_world
            )[::-1]
        else:
            self.comms = [MPI.COMM_SELF, MPI.COMM_SELF, MPI.COMM_SELF]
        self.ranks = [me.rank for me in self.comms]

    def get_file_name(self):
        """Return the path of the solution output file for this configuration."""
        res = self.args['res']
        return f'{self.base_path}/data/{type(self).__name__}-res{res}.pySDC'

    def get_LogToFile(self, *args, **kwargs):
        """Return the configured ``LogToFile`` hook class, or ``None`` on non-logging ranks.

        Only rank 0 of the sweeper communicator (``self.comms[1]``) logs to file.
        Note that the configuration is stored on the hook *class* itself.
        """
        if self.comms[1].rank > 0:
            return None
        from pySDC.implementations.hooks.log_solution import LogToFile

        LogToFile.filename = self.get_file_name()
        LogToFile.time_increment = self.logging_time_increment
        LogToFile.allow_overwriting = True

        return LogToFile

    def get_description(self, *args, MPIsweeper=False, useGPU=False, **kwargs):
        """Assemble the skeleton of the pySDC description dictionary.

        Subclasses are expected to fill in the problem class and its parameters.

        Args:
            MPIsweeper (bool): Use an MPI-parallel sweeper.
            useGPU (bool): Run the problem on GPU.

        Returns:
            dict: Description with default sweeper/level/step parameters.
        """
        description = {}
        description['problem_class'] = None
        description['problem_params'] = {'useGPU': useGPU, 'comm': self.comms[2]}
        description['sweeper_class'] = self.get_sweeper(useMPI=MPIsweeper)
        description['sweeper_params'] = {'initial_guess': 'copy'}
        description['level_params'] = {}
        description['step_params'] = {}
        description['convergence_controllers'] = {}

        if self.get_LogToFile():
            # Strip the '.pySDC' suffix to get the common prefix for stats files.
            path = self.get_file_name()[:-6]
            description['convergence_controllers'][LogStats] = {'path': path}

        if MPIsweeper:
            description['sweeper_params']['comm'] = self.comms[1]
        return description

    def get_controller_params(self, *args, logger_level=15, **kwargs):
        """Assemble default controller parameters.

        Args:
            logger_level (int): Logger level on rank 0; other ranks are quieted to 40.

        Returns:
            dict: Controller parameters with the standard hooks attached.
        """
        from pySDC.implementations.hooks.log_work import LogWork
        from pySDC.implementations.hooks.log_step_size import LogStepSize
        from pySDC.implementations.hooks.log_restarts import LogRestarts

        controller_params = {}
        controller_params['logger_level'] = logger_level if self.comm_world.rank == 0 else 40
        controller_params['hook_class'] = [LogWork, LogStepSize, LogRestarts]
        logToFile = self.get_LogToFile()
        if logToFile:
            controller_params['hook_class'] += [logToFile]
        controller_params['mssdc_jac'] = False
        return controller_params

    def get_sweeper(self, useMPI):
        """Return the sweeper class matching ``self.sweeper_type``.

        Args:
            useMPI (bool): Select the MPI-parallel variant of the sweeper.

        Raises:
            NotImplementedError: For unknown ``sweeper_type`` values.
        """
        if useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order_MPI import imex_1st_order_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'IMEX':
            from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order as sweeper
        elif useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper
        elif not useMPI and self.sweeper_type == 'generic_implicit':
            from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper
        else:
            raise NotImplementedError(f'Don\'t know the sweeper for {self.sweeper_type=}')

        return sweeper

    def prepare_caches(self, prob):
        """Hook for subclasses to warm up problem-specific caches. Default: no-op."""
        pass

    def get_path(self, *args, ranks=None, **kwargs):
        """Return an identifier string encoding configuration, step rank, and space rank."""
        ranks = self.ranks if ranks is None else ranks
        return f'{type(self).__name__}{self.args_to_str()}-{ranks[0]}-{ranks[2]}'

    def args_to_str(self, args=None):
        """Encode the relevant arguments in a string usable in file names."""
        args = self.args if args is None else args
        name = ''

        name = f'{name}-res_{args["res"]}'
        name = f'{name}-useGPU_{args["useGPU"]}'
        name = f'{name}-procs_{args["procs"][0]}_{args["procs"][1]}_{args["procs"][2]}'
        return name

    def plot(self, P, idx, num_procs_list):
        """Plot a logged solution. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_initial_condition(self, P, *args, restart_idx=0, **kwargs):
        """Get the initial condition, either fresh or restarted from file.

        Args:
            P: The problem instance.
            restart_idx (int): Index of the stored solution to restart from
                (0 means a fresh start from ``P.u_exact(t=0)``).

        Returns:
            tuple: Initial condition and its associated time.
        """
        # Bug fix: an obsolete LogToFile-based restart path that followed the
        # return statements below was unreachable and has been removed.
        if restart_idx == 0:
            return P.u_exact(t=0), 0

        from pySDC.helpers.fieldsIO import FieldsIO

        P.setUpFieldsIO()
        outfile = FieldsIO.fromFile(self.get_file_name())

        t0, solution = outfile.readField(restart_idx)
        solution = solution[: P.spectral.ncomponents, ...]

        u0 = P.u_init

        if P.spectral_space:
            u0[...] = P.transform(solution)
        else:
            u0[...] = solution

        return u0, t0
class LogStats(ConvergenceController):
    """Convergence controller that pickles the controller statistics to disk in chunks.

    Works together with a ``LogToFile``-style hook (``params.hook``): whenever
    the hook's counter advances past this controller's own counter, the stats
    accumulated since the last dump are written next to the solution files
    (common path prefix ``params.path``) and then reset.
    """

    def get_stats_path(self, index=0):
        # Path of the stats pickle for the given chunk index.
        return f'{self.params.path}_{index:06d}-stats.pickle'

    def merge_all_stats(self, controller):
        """Read all previously pickled stats chunks and merge them with the live stats.

        Missing or truncated chunk files only produce a warning so a partial
        run can still be post-processed.

        Args:
            controller: The pySDC controller holding the current (unpickled) stats.

        Returns:
            dict: The merged statistics.
        """
        hook = self.params.hook

        stats = {}
        # NOTE(review): range stops at hook.counter - 1; confirm the last chunk
        # index is covered given the `hook.counter - 2` indexing used when storing.
        for i in range(hook.counter - 1):
            try:
                with open(self.get_stats_path(index=i), 'rb') as file:
                    _stats = pickle.load(file)
                    stats = {**stats, **_stats}
            except (FileNotFoundError, EOFError):
                print(f'Warning: No stats found at path {self.get_stats_path(index=i)}')

        # Stats of the current, not yet pickled, interval come from the controller.
        stats = {**stats, **controller.return_stats()}
        return stats

    def reset_stats(self, controller):
        """Clear the stats of every hook so the next pickle only contains new entries."""
        for hook in controller.hooks:
            hook.reset_stats()
        self.logger.debug('Reset stats')

    def setup(self, controller, params, *args, **kwargs):
        """Set default parameters and initialize the counter from the hook.

        Defaults: run last (control_order 999) and use ``LogToFile`` as the hook
        unless one was supplied.
        """
        params['control_order'] = 999
        if 'hook' not in params.keys():
            from pySDC.implementations.hooks.log_solution import LogToFile

            params['hook'] = LogToFile

        # Start in sync with however far the hook has already counted.
        self.counter = params['hook'].counter
        return super().setup(controller, params, *args, **kwargs)

    def post_step_processing(self, controller, S, **kwargs):
        """Dump and reset the stats whenever the hook has written a new solution file.

        Args:
            controller: The pySDC controller.
            S: The current step.
        """
        hook = self.params.hook

        P = S.levels[0].prob

        # The hook's counter advancing past ours signals that a new solution
        # file was written and the stats for that interval should be dumped.
        while self.counter < hook.counter:
            path = self.get_stats_path(index=hook.counter - 2)
            stats = controller.return_stats()
            # Only one rank per sweeper/space communicator writes to disk.
            # NOTE(review): when the sweeper comm exists and its rank is 0,
            # P.comm.rank is not checked because of the elif — confirm intended.
            store = True
            if hasattr(S.levels[0].sweep, 'comm') and S.levels[0].sweep.comm.rank > 0:
                store = False
            elif P.comm.rank > 0:
                store = False
            if store:
                with open(path, 'wb') as file:
                    pickle.dump(stats, file)
                    self.log(f'Stored stats in {path!r}', S)
            self.reset_stats(controller)
            self.counter = hook.counter

    def post_run_processing(self, controller, S, **kwargs):
        """Flush remaining stats and rebind ``controller.return_stats`` to the merged stats."""
        self.post_step_processing(controller, S, **kwargs)

        stats = self.merge_all_stats(controller)

        def return_stats():
            return stats

        # After the run, return_stats yields the union of all pickled chunks
        # plus the live stats, so post-processing sees the whole history.
        controller.return_stats = return_stats