Coverage for pySDC/implementations/sweeper_classes/generic_implicit_MPI.py: 90%
105 statements
coverage.py v7.8.0, created at 2025-04-01 13:12 +0000
from mpi4py import MPI

from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.core.sweeper import Sweeper, ParameterError
import logging


class SweeperMPI(Sweeper):
    """
    MPI based sweeper where each rank administers one collocation node. Adapt sweepers to MPI by use of multiple
    inheritance. See for example the `generic_implicit_MPI` sweeper, which has a class definition:

    ```
    class generic_implicit_MPI(SweeperMPI, generic_implicit):
    ```

    This means it inherits from both `SweeperMPI` and `generic_implicit`. The hierarchy works such that functions are
    first called from `SweeperMPI` and then from `generic_implicit`. For instance, in the `__init__` function, the
    `SweeperMPI` class adds a communicator and nothing else. The `generic_implicit` class then adds a preconditioner,
    and so on. It's a bit confusing because `self.params` is overwritten in the second call to the `__init__` of the
    core `Sweeper` class, but the `SweeperMPI` class adds its parameters to the `params` dictionary, so they are
    passed along again when `generic_implicit` is initialized.
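
    Since both parents ultimately derive from the core `Sweeper` class, Python's method resolution order decides which
    implementation of a method is found first; this can be checked directly (a sketch, not part of the original
    docstring):

    ```
    generic_implicit_MPI.__mro__
    # -> (generic_implicit_MPI, SweeperMPI, generic_implicit, Sweeper, object)
    ```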
23 """

    def __init__(self, params):
        self.logger = logging.getLogger('sweeper')

        if 'comm' not in params.keys():
            params['comm'] = MPI.COMM_WORLD
            self.logger.debug('Using MPI.COMM_WORLD for the communicator because none was supplied in the params.')
        super().__init__(params)

        if self.params.comm.size != self.coll.num_nodes:
            raise NotImplementedError(
                f'The communicator in the {type(self).__name__} sweeper needs to have one rank for each node as of now! That means we need {self.coll.num_nodes} nodes, but got {self.params.comm.size} processes.'
            )

    @property
    def comm(self):
        return self.params.comm

    @property
    def rank(self):
        return self.comm.rank

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend is copied from the last collocation node here; a full evaluation of the collocation formula
        (do_coll_update == True) is not implemented in this class.

        Returns:
            None
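
        Note (added summary, not part of the original docstring): the last rank owns the value at the final node and
        broadcasts it to all other ranks, so every rank ends up with the same `uend`.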
54 """
        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            # a copy is sufficient
            root = self.comm.Get_size() - 1
            if self.comm.rank == root:
                L.uend = P.dtype_u(L.u[-1])
            else:
                L.uend = P.dtype_u(L.u[0])
            self.comm.Bcast(L.uend, root=root)
        else:
            raise NotImplementedError('require last node to be identical with right interval boundary')

        return None

    def compute_residual(self, stage=None):
        """
        Computation of the residual using the collocation matrix Q

        Args:
            stage (str): The current stage of the step the level belongs to
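
        The reduction used depends on `residual_type` (a summary of the branches below, not part of the original
        docstring):

        ```
        full_abs : max over nodes of ||res_m||            (allreduce, MPI.MAX)
        last_abs : ||res_M|| at the last node             (bcast from the last rank)
        full_rel : max over nodes of ||res_m|| / ||u[0]|| (allreduce, MPI.MAX)
        last_rel : ||res_M|| / ||u[0]|| at the last node  (bcast from the last rank)
        ```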
79 """
        L = self.level

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            L.status.residual = 0.0 if L.status.residual is None else L.status.residual
            return None

        # compute the residual for each node

        # build QF(u)
        res = self.integrate(last_only=L.params.residual_type[:4] == 'last')
        res += L.u[0] - L.u[self.rank + 1]
        # add tau if associated
        if L.tau[self.rank] is not None:
            res += L.tau[self.rank]

        # use abs function from data type here
        res_norm = abs(res)

        # find maximal residual over the nodes
        if L.params.residual_type == 'full_abs':
            L.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX)
        elif L.params.residual_type == 'last_abs':
            L.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1)
        elif L.params.residual_type == 'full_rel':
            L.status.residual = self.comm.allreduce(res_norm / abs(L.u[0]), op=MPI.MAX)
        elif L.params.residual_type == 'last_rel':
            L.status.residual = self.comm.bcast(res_norm / abs(L.u[0]), root=self.comm.size - 1)
        else:
            raise NotImplementedError(f'residual type \"{L.params.residual_type}\" not implemented!')

        # indicate that the residual has seen the new values
        L.status.updated = False

        return None

    def predict(self):
        """
        Predictor to fill values at nodes before first sweep

        Default prediction for the sweepers, only copies the values to all collocation nodes
        and evaluates the RHS of the ODE there
        """
        L = self.level
        P = L.prob

        # evaluate RHS at left point
        L.f[0] = P.eval_f(L.u[0], L.time)

        m = self.rank

        if self.params.initial_guess == 'spread':
            # copy u[0] to all collocation nodes, evaluate RHS
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m])
        elif self.params.initial_guess == 'copy':
            # copy u[0] and RHS evaluation to all collocation nodes
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.dtype_f(L.f[0])
        elif self.params.initial_guess == 'zero':
            # zero initial guess for u and RHS
            L.u[m + 1] = P.dtype_u(init=P.init, val=0.0)
            L.f[m + 1] = P.dtype_f(init=P.init, val=0.0)
        else:
            raise ParameterError(f'initial_guess option {self.params.initial_guess} not implemented')

        # indicate that this level is now ready for sweeps
        L.status.unlocked = True
        L.status.updated = True

    def communicate_tau_correction_for_full_interval(self):
        """
        Broadcast the tau correction for the full interval, which is stored in the last entry of `L.tau`, from the
        last rank to all others. Every other rank first allocates a buffer to receive into.
        """
        L = self.level
        P = L.prob
        if self.rank < self.comm.size - 1:
            L.tau[-1] = P.u_init
        self.comm.Bcast(L.tau[-1], root=self.comm.size - 1)
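

# A minimal usage sketch, not part of the original file: the sweepers in this module expect
# one MPI rank per collocation node and read the communicator from the `comm` entry of the
# parameter dictionary (defaulting to MPI.COMM_WORLD, see `SweeperMPI.__init__` above).
# The parameter values below are illustrative assumptions, not prescribed defaults.
def _example_sweeper_params(comm=None):
    comm = MPI.COMM_WORLD if comm is None else comm
    return {
        'num_nodes': comm.size,  # one rank per collocation node (enforced in `SweeperMPI.__init__`)
        'quad_type': 'RADAU-RIGHT',  # right boundary is a node, so `compute_end_point` can copy
        'QI': 'MIN-SR-FLEX',  # a diagonal preconditioner, as required for node parallelism
        'comm': comm,
    }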


class generic_implicit_MPI(SweeperMPI, generic_implicit):
    """
    Generic implicit sweeper parallelized across the nodes.
    Please supply a communicator as `comm` to the parameters!

    Attributes:
        rank (int): MPI rank
    """

    def integrate(self, last_only=False):
        """
        Integrates the right-hand side

        Args:
            last_only (bool): Integrate only the last node for the residual or all of them

        Returns:
            dtype_u: the integral for this rank's node (only the last rank receives a meaningful value when
            `last_only` is True)
        """
        L = self.level
        P = L.prob

        me = P.dtype_u(P.init, val=0.0)
        for m in [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes):
            recvBuf = me if m == self.rank else None
            self.comm.Reduce(
                L.dt * self.coll.Qmat[m + 1, self.rank + 1] * L.f[self.rank + 1], recvBuf, root=m, op=MPI.SUM
            )

        return me

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Returns:
            None
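
        Because the preconditioner `QI` is diagonal here, each rank solves only for its own node. Written out for
        node `m = rank + 1` with `qd = QI[m, m]` (a worked form of the code below, `tau` only if a tau correction
        is present):

        ```
        u[m]^(k+1) - dt * qd * f(u[m]^(k+1)) = u[0] + (dt * Q f(u^k))[m] - dt * qd * f(u[m]^k) + tau
        ```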
197 """
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # update the MIN-SR-FLEX preconditioner
        self.updateVariableCoeffs(L.status.sweep)

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        rhs = self.integrate()

        rhs -= L.dt * self.QI[self.rank + 1, self.rank + 1] * L.f[self.rank + 1]

        # add initial value
        rhs += L.u[0]
        # add tau if associated
        if L.tau[self.rank] is not None:
            rhs += L.tau[self.rank]

        # since the preconditioner is diagonal, the rhs needs no new values from previous nodes of this sweep

        # implicit solve with prefactor stemming from the diagonal of Qd
        L.u[self.rank + 1] = P.solve_system(
            rhs,
            L.dt * self.QI[self.rank + 1, self.rank + 1],
            L.u[self.rank + 1],
            L.time + L.dt * self.coll.nodes[self.rank],
        )
        # update function values
        L.f[self.rank + 1] = P.eval_f(L.u[self.rank + 1], L.time + L.dt * self.coll.nodes[self.rank])

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the collocation formula unless do_coll_update == False
        and the last node coincides with the right interval boundary, in which case a simple copy suffices.

        Returns:
            None
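
        In the full-evaluation branch, each rank contributes its weighted right-hand side value, so the Allreduce
        assembles the collocation update (a worked form of the code below, `tau[-1]` only if a tau correction is
        present):

        ```
        uend = u[0] + dt * sum_m weights[m] * f[m + 1]  (+ tau[-1])
        ```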
247 """
        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            L.uend = P.dtype_u(L.u[0])
            self.comm.Allreduce(L.dt * self.coll.weights[self.rank] * L.f[self.rank + 1], L.uend, op=MPI.SUM)
            L.uend += L.u[0]

            # add up tau correction of the full interval (last entry)
            if L.tau[self.rank] is not None:
                self.communicate_tau_correction_for_full_interval()
                L.uend += L.tau[-1]
        return None
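

# A minimal, self-contained sketch, not part of the original file, of the communication
# pattern in `generic_implicit_MPI.integrate` above: each rank owns one column contribution
# `Q[m + 1, rank + 1] * f[rank + 1]`, and a Reduce with `root=m` sums the contributions so
# that rank m ends up with the m-th entry of the (dt-scaled) product Q @ f. Plain numpy
# arrays stand in for pySDC's data types; run with e.g. `mpiexec -n 3 python <file>`.
def _example_distributed_Qf(comm, dt, Q, f_local):
    import numpy as np

    me = np.zeros_like(f_local)
    for m in range(comm.size):
        recvBuf = me if m == comm.rank else None  # only the root of this Reduce receives
        comm.Reduce(dt * Q[m + 1, comm.rank + 1] * f_local, recvBuf, root=m, op=MPI.SUM)
    return me  # on rank m: dt * sum_j Q[m + 1, j + 1] * f_j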