Coverage for pySDC/implementations/sweeper_classes/generic_implicit_MPI.py: 89% (104 statements)
from mpi4py import MPI

from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.core.sweeper import Sweeper, ParameterError
import logging


class SweeperMPI(Sweeper):
    """
    MPI based sweeper where each rank administers one collocation node. Adapt sweepers to MPI by use of multiple
    inheritance. See for example the `generic_implicit_MPI` sweeper, which has the class definition:

    ```
    class generic_implicit_MPI(SweeperMPI, generic_implicit):
    ```

    This means it inherits from both `SweeperMPI` and `generic_implicit`. The hierarchy works such that functions are
    first called from `SweeperMPI` and then from `generic_implicit`. For instance, in the `__init__` function, the
    `SweeperMPI` class adds a communicator and nothing else. The `generic_implicit` class then adds a preconditioner,
    and so on. It's a bit confusing because `self.params` is overwritten in the second call to the `__init__` of the
    core `Sweeper` class, but the parameters that `SweeperMPI` added to the `params` dictionary are simply added again
    in `generic_implicit`.
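
    As a minimal sketch of what this means for the method resolution order (only the class names above are assumed),
    Python resolves the cooperative calls like this:

    ```
    generic_implicit_MPI.__mro__
    # (generic_implicit_MPI, SweeperMPI, generic_implicit, Sweeper, ...)
    ```

    so a `super().__init__` call issued in `SweeperMPI` continues in `generic_implicit` before reaching the core
    `Sweeper` class.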
23 """

    def __init__(self, params, level):
        """
        Initialization routine for the sweeper

        Args:
            params: parameters for the sweeper
            level (pySDC.Level.level): the level that uses this sweeper
        """
        self.logger = logging.getLogger('sweeper')

        if 'comm' not in params.keys():
            params['comm'] = MPI.COMM_WORLD
            self.logger.debug('Using MPI.COMM_WORLD for the communicator because none was supplied in the params.')
        super().__init__(params, level)

        if self.params.comm.size != self.coll.num_nodes:
            raise NotImplementedError(
                f'The communicator in the {type(self).__name__} sweeper needs to have one rank for each node as of now! That means we need {self.coll.num_nodes} nodes, but got {self.params.comm.size} processes.'
            )

    @property
    def comm(self):
        return self.params.comm

    @property
    def rank(self):
        return self.comm.rank

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the Picard formulation unless do_coll_update==False

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            # a copy is sufficient
            root = self.comm.Get_size() - 1
            if self.comm.rank == root:
                L.uend = P.dtype_u(L.u[-1])
            else:
                L.uend = P.dtype_u(L.u[0])
            self.comm.Bcast(L.uend, root=root)
        else:
            raise NotImplementedError('require last node to be identical with right interval boundary')

        return None

    def compute_residual(self, stage=None):
        """
        Computation of the residual using the collocation matrix Q

        Args:
            stage (str): The current stage of the step the level belongs to
        """
        L = self.level

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            L.status.residual = 0.0 if L.status.residual is None else L.status.residual
            return None

        # compute the residual for each node

        # build QF(u)
        res = self.integrate(last_only=L.params.residual_type[:4] == 'last')
        res += L.u[0] - L.u[self.rank + 1]
        # add tau if associated
        if L.tau[self.rank] is not None:
            res += L.tau[self.rank]
        # use abs function from data type here
        res_norm = abs(res)

        # find maximal residual over the nodes
        if L.params.residual_type == 'full_abs':
            L.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX)
        elif L.params.residual_type == 'last_abs':
            L.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1)
        elif L.params.residual_type == 'full_rel':
            L.status.residual = self.comm.allreduce(res_norm / abs(L.u[0]), op=MPI.MAX)
        elif L.params.residual_type == 'last_rel':
            L.status.residual = self.comm.bcast(res_norm / abs(L.u[0]), root=self.comm.size - 1)
        else:
            raise NotImplementedError(f'residual type \"{L.params.residual_type}\" not implemented!')

        # indicate that the residual has seen the new values
        L.status.updated = False

        return None
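
    # For reference, the quantity rank m assembles above for its node is
    #     res_m = u_0 + dt * sum_j Q[m, j] * f(u_j, t_j) - u_m (+ tau_m),
    # after which `residual_type` decides whether the maximum over all nodes or only the last node's value is used,
    # and whether it is scaled by abs(u_0).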

    def predict(self):
        """
        Predictor to fill values at nodes before first sweep

        Default prediction for the sweepers, only copies the values to all collocation nodes
        and evaluates the RHS of the ODE there
        """
        L = self.level
        P = L.prob

        # evaluate RHS at left point
        L.f[0] = P.eval_f(L.u[0], L.time)

        m = self.rank

        if self.params.initial_guess == 'spread':
            # copy u[0] to all collocation nodes, evaluate RHS
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m])
        elif self.params.initial_guess == 'copy':
            # copy u[0] and RHS evaluation to all collocation nodes
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.dtype_f(L.f[0])
        elif self.params.initial_guess == 'zero':
            # zero solution for u and RHS
            L.u[m + 1] = P.dtype_u(init=P.init, val=0.0)
            L.f[m + 1] = P.dtype_f(init=P.init, val=0.0)
        else:
            raise ParameterError(f'initial_guess option {self.params.initial_guess} not implemented')

        # indicate that this level is now ready for sweeps
        L.status.unlocked = True
        L.status.updated = True

    def communicate_tau_correction_for_full_interval(self):
        """
        Broadcast the tau correction for the full interval (the last entry of `L.tau`) from the last rank to all ranks.
        """
        L = self.level
        P = L.prob
        if self.rank < self.comm.size - 1:
            L.tau[-1] = P.u_init
        self.comm.Bcast(L.tau[-1], root=self.comm.size - 1)


class generic_implicit_MPI(SweeperMPI, generic_implicit):
    """
    Generic implicit sweeper parallelized across the nodes.
    Please supply a communicator as `comm` to the parameters!

    Attributes:
        rank (int): MPI rank
    """

    def integrate(self, last_only=False):
        """
        Integrates the right-hand side

        Args:
            last_only (bool): Integrate only the last node for the residual or all of them

        Returns:
            dtype_u: the integral for the node owned by this rank
        """
        L = self.level
        P = L.prob

        me = P.dtype_u(P.init, val=0.0)
        for m in [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes):
            # every rank contributes dt * Q[m, rank] * f at its own node; rank m receives the sum for node m
            recvBuf = me if m == self.rank else None
            self.comm.Reduce(
                L.dt * self.coll.Qmat[m + 1, self.rank + 1] * L.f[self.rank + 1], recvBuf, root=m, op=MPI.SUM
            )

        return me

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        rhs = self.integrate()

        rhs -= L.dt * self.QI[self.rank + 1, self.rank + 1] * L.f[self.rank + 1]

        # add initial value
        rhs += L.u[0]
        # add tau if associated
        if L.tau[self.rank] is not None:
            rhs += L.tau[self.rank]

        # build rhs, consisting of the known values from above and new values from previous nodes (at k+1)

        # implicit solve with prefactor stemming from the diagonal of Qd
        L.u[self.rank + 1] = P.solve_system(
            rhs,
            L.dt * self.QI[self.rank + 1, self.rank + 1],
            L.u[self.rank + 1],
            L.time + L.dt * self.coll.nodes[self.rank],
        )
        # update function values
        L.f[self.rank + 1] = P.eval_f(L.u[self.rank + 1], L.time + L.dt * self.coll.nodes[self.rank])

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the Picard formulation unless do_coll_update==False

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            L.uend = P.dtype_u(L.u[0])
            self.comm.Allreduce(L.dt * self.coll.weights[self.rank] * L.f[self.rank + 1], L.uend, op=MPI.SUM)
            L.uend += L.u[0]

            # add up tau correction of the full interval (last entry)
            if L.tau[self.rank] is not None:
                self.communicate_tau_correction_for_full_interval()
                L.uend += L.tau[-1]
        return None
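

# Minimal usage sketch: one MPI rank per collocation node, as required by the size check in `SweeperMPI.__init__`.
# Only `comm`, `num_nodes`, `quad_type` and `QI` follow from this file and the standard sweeper parameters; the
# description dict layout and any problem/controller classes are assumed from the usual pySDC setup and not shown.
if __name__ == '__main__':
    comm = MPI.COMM_WORLD  # run with e.g. `mpiexec -n 3 python ...` so that comm.size matches num_nodes

    sweeper_params = {
        'comm': comm,                # communicator with one rank per collocation node
        'num_nodes': comm.size,      # must equal comm.size, otherwise __init__ raises NotImplementedError
        'quad_type': 'RADAU-RIGHT',  # right end point is a node, so compute_end_point can copy/broadcast
        'QI': 'MIN-SR-S',            # a diagonal preconditioner (name assumed) keeps the node solves independent
    }

    description = {
        'sweeper_class': generic_implicit_MPI,
        'sweeper_params': sweeper_params,
        # 'problem_class', 'problem_params', 'level_params', 'step_params' etc. as in any other pySDC run
    }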