Coverage for pySDC/implementations/sweeper_classes/generic_implicit_MPI.py: 93%
98 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1from mpi4py import MPI
3from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
4from pySDC.core.Sweeper import sweeper, ParameterError
5import logging
class SweeperMPI(sweeper):
    """
    MPI based sweeper where each rank administers one collocation node. Adapt sweepers to MPI by use of multiple
    inheritance. See for example the `generic_implicit_MPI` sweeper, which has a class definition:

    ```
    class generic_implicit_MPI(SweeperMPI, generic_implicit):
    ```

    this means it inherits both from `SweeperMPI` and `generic_implicit`. The hierarchy works such that functions are
    first called from `SweeperMPI` and then from `generic_implicit`. For instance, in the `__init__` function, the
    `SweeperMPI` class adds a communicator and nothing else. The `generic_implicit` class adds a preconditioner and
    so on. It's a bit confusing because `self.params` is overwritten in the second call to the `__init__` of the core
    `sweeper` class, but the `SweeperMPI` class adds parameters to the `params` dictionary, which will again be added
    in `generic_implicit`.
    """

    def __init__(self, params):
        """
        Set up the sweeper, defaulting the communicator to MPI.COMM_WORLD when none is supplied.

        Args:
            params (dict): Sweeper parameters, optionally containing an MPI communicator under the key `comm`.

        Raises:
            NotImplementedError: If the communicator size differs from the number of collocation nodes.
        """
        self.logger = logging.getLogger('sweeper')

        if 'comm' not in params:
            params['comm'] = MPI.COMM_WORLD
            self.logger.debug('Using MPI.COMM_WORLD for the communicator because none was supplied in the params.')
        super().__init__(params)

        # one rank per collocation node is the only decomposition supported so far
        if self.params.comm.size != self.coll.num_nodes:
            raise NotImplementedError(
                f'The communicator in the {type(self).__name__} sweeper needs to have one rank for each node as of now! That means we need {self.coll.num_nodes} nodes, but got {self.params.comm.size} processes.'
            )

    @property
    def comm(self):
        """MPI communicator supplied via the parameters."""
        return self.params.comm

    @property
    def rank(self):
        """Rank of this process within the communicator."""
        return self.comm.rank

    def compute_end_point(self):
        """
        Compute u at the right point of the interval.

        Only the simple copy of the last node is supported here; this requires the right interval boundary to be a
        collocation node and `do_coll_update` to be False.

        Returns:
            None

        Raises:
            NotImplementedError: If the last node is not identical to the right interval boundary.
        """
        lvl = self.level
        prob = lvl.prob
        lvl.uend = prob.dtype_u(prob.init, val=0.0)

        # only a copy of the last node is implemented; everything else is rejected
        if not self.coll.right_is_node or self.params.do_coll_update:
            raise NotImplementedError('require last node to be identical with right interval boundary')

        # the last rank administers the last node: copy there and broadcast to everyone else
        last = self.comm.Get_size() - 1
        if self.comm.rank == last:
            lvl.uend[:] = lvl.u[-1]
        self.comm.Bcast(lvl.uend, root=last)

        return None

    def compute_residual(self, stage=None):
        """
        Computation of the residual using the collocation matrix Q.

        Args:
            stage (str): The current stage of the step the level belongs to
        """
        lvl = self.level

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            lvl.status.residual = 0.0 if lvl.status.residual is None else lvl.status.residual
            return None

        rtype = lvl.params.residual_type

        # build QF(u) for this rank's node (only the last node if a 'last' residual type was requested)
        res = self.integrate(last_only=rtype.startswith('last'))
        res += lvl.u[0] - lvl.u[self.rank + 1]
        # add tau correction if associated
        if lvl.tau[self.rank] is not None:
            res += lvl.tau[self.rank]
        # use abs function from the data type here
        res_norm = abs(res)

        # combine the per-node norms according to the requested residual type
        if rtype == 'full_abs':
            lvl.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX)
        elif rtype == 'last_abs':
            lvl.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1)
        elif rtype == 'full_rel':
            lvl.status.residual = self.comm.allreduce(res_norm / abs(lvl.u[0]), op=MPI.MAX)
        elif rtype == 'last_rel':
            lvl.status.residual = self.comm.bcast(res_norm / abs(lvl.u[0]), root=self.comm.size - 1)
        else:
            raise NotImplementedError(f'residual type \"{rtype}\" not implemented!')

        # indicate that the residual has seen the new values
        lvl.status.updated = False

        return None

    def predict(self):
        """
        Predictor to fill values at the nodes before the first sweep.

        Each rank fills only its own collocation node, according to the `initial_guess` parameter, and evaluates the
        RHS of the ODE there.

        Raises:
            ParameterError: If an unknown `initial_guess` option is supplied.
        """
        lvl = self.level
        prob = lvl.prob

        # evaluate RHS at left point
        lvl.f[0] = prob.eval_f(lvl.u[0], lvl.time)

        node = self.rank
        guess = self.params.initial_guess

        if guess == 'spread':
            # copy u[0] to this rank's collocation node, evaluate RHS there
            lvl.u[node + 1] = prob.dtype_u(lvl.u[0])
            lvl.f[node + 1] = prob.eval_f(lvl.u[node + 1], lvl.time + lvl.dt * self.coll.nodes[node])
        elif guess == 'copy':
            # copy u[0] and its RHS evaluation to this rank's collocation node
            lvl.u[node + 1] = prob.dtype_u(lvl.u[0])
            lvl.f[node + 1] = prob.dtype_f(lvl.f[0])
        elif guess == 'zero':
            # zero solution for u and RHS
            lvl.u[node + 1] = prob.dtype_u(init=prob.init, val=0.0)
            lvl.f[node + 1] = prob.dtype_f(init=prob.init, val=0.0)
        else:
            raise ParameterError(f'initial_guess option {guess} not implemented')

        # indicate that this level is now ready for sweeps
        lvl.status.unlocked = True
        lvl.status.updated = True
class generic_implicit_MPI(SweeperMPI, generic_implicit):
    """
    Generic implicit sweeper parallelized across the nodes.
    Please supply a communicator as `comm` to the parameters!

    Attributes:
        rank (int): MPI rank
    """

    def integrate(self, last_only=False):
        """
        Integrates the right-hand side.

        Each rank contributes dt * Q[m, rank] * f[rank] for node m; the sum over ranks is reduced onto the rank that
        administers node m, so only that rank receives a non-trivial result for its node.

        Args:
            last_only (bool): Integrate only the last node for the residual or all of them

        Returns:
            list of dtype_u: containing the integral as values
        """
        lvl = self.level
        prob = lvl.prob

        result = prob.dtype_u(prob.init, val=0.0)
        nodes = [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes)
        for m in nodes:
            # only the rank owning node m receives the reduced sum
            recv_buf = result if m == self.rank else None
            self.comm.Reduce(
                lvl.dt * self.coll.Qmat[m + 1, self.rank + 1] * lvl.f[self.rank + 1], recv_buf, root=m, op=MPI.SUM
            )

        return result

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes.

        Returns:
            None
        """
        lvl = self.level
        prob = lvl.prob

        # only if the level has been touched before
        assert lvl.status.unlocked

        # prefactor stemming from the diagonal of the preconditioner QI (used for both rhs and the solve)
        dt_qd = lvl.dt * self.QI[self.rank + 1, self.rank + 1]

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k) and subtract the diagonal part QdF(u^k)
        rhs = self.integrate()
        rhs -= dt_qd * lvl.f[self.rank + 1]
        # add initial value
        rhs += lvl.u[0]
        # add tau correction if associated
        if lvl.tau[self.rank] is not None:
            rhs += lvl.tau[self.rank]

        # implicit solve with prefactor stemming from the diagonal of Qd
        t_node = lvl.time + lvl.dt * self.coll.nodes[self.rank]
        lvl.u[self.rank + 1] = prob.solve_system(rhs, dt_qd, lvl.u[self.rank + 1], t_node)
        # update function values
        lvl.f[self.rank + 1] = prob.eval_f(lvl.u[self.rank + 1], t_node)

        # indicate presence of new values at this level
        lvl.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval.

        Performs a full evaluation of the Picard formulation unless the right boundary is a collocation node and
        `do_coll_update` is False, in which case a simple copy (via the parent class) suffices.

        Returns:
            None
        """
        lvl = self.level
        prob = lvl.prob
        lvl.uend = prob.dtype_u(prob.init, val=0.0)

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            # full collocation update: uend = u0 + dt * sum_m weights[m] * f[m] (+ tau)
            lvl.uend = prob.dtype_u(lvl.u[0])
            self.comm.Allreduce(lvl.dt * self.coll.weights[self.rank] * lvl.f[self.rank + 1], lvl.uend, op=MPI.SUM)
            lvl.uend += lvl.u[0]

            # add up tau correction of the full interval (last entry)
            if lvl.tau[-1] is not None:
                lvl.uend += lvl.tau[-1]
        return None