Coverage for pySDC/implementations/sweeper_classes/generic_implicit_MPI.py: 90%

105 statements  

coverage.py v7.8.0, created at 2025-04-01 13:12 +0000

from mpi4py import MPI

from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.core.sweeper import Sweeper, ParameterError
import logging


class SweeperMPI(Sweeper):
    """
    MPI based sweeper where each rank administers one collocation node. Sweepers are adapted to MPI by use of
    multiple inheritance. See for example the `generic_implicit_MPI` sweeper, which has the class definition:

    ```
    class generic_implicit_MPI(SweeperMPI, generic_implicit):
    ```

    This means it inherits from both `SweeperMPI` and `generic_implicit`. The hierarchy works such that functions are
    first called from `SweeperMPI` and then from `generic_implicit`. For instance, in the `__init__` function, the
    `SweeperMPI` class adds a communicator and nothing else. The `generic_implicit` class adds a preconditioner and
    so on. It's a bit confusing because `self.params` is overwritten in the second call to the `__init__` of the core
    `Sweeper` class, but the `SweeperMPI` class adds parameters to the `params` dictionary, which will again be added
    in `generic_implicit`.
    """

    def __init__(self, params):
        self.logger = logging.getLogger('sweeper')

        if 'comm' not in params.keys():
            params['comm'] = MPI.COMM_WORLD
            self.logger.debug('Using MPI.COMM_WORLD for the communicator because none was supplied in the params.')
        super().__init__(params)

        if self.params.comm.size != self.coll.num_nodes:
            raise NotImplementedError(
                f'The communicator in the {type(self).__name__} sweeper needs to have one rank per collocation node '
                f'as of now! That means we need {self.coll.num_nodes} processes, but got {self.params.comm.size}.'
            )
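
    # Example (a sketch, not from this file): instead of relying on the
    # MPI.COMM_WORLD default, a sub-communicator sized to the number of
    # collocation nodes can be supplied via the params:
    #
    #     sub_comm = MPI.COMM_WORLD.Split(color=MPI.COMM_WORLD.rank // num_nodes, key=MPI.COMM_WORLD.rank)
    #     params['comm'] = sub_comm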

    @property
    def comm(self):
        return self.params.comm

    @property
    def rank(self):
        return self.comm.rank

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the Picard formulation unless do_coll_update==False

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            # a copy is sufficient, but the value lives only on the last rank, so broadcast it
            root = self.comm.Get_size() - 1
            if self.comm.rank == root:
                L.uend = P.dtype_u(L.u[-1])
            else:
                L.uend = P.dtype_u(L.u[0])
            self.comm.Bcast(L.uend, root=root)
        else:
            raise NotImplementedError('require last node to be identical with right interval boundary')

        return None
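
    # For reference: with a collocation update enabled, the end point would be the
    # full quadrature evaluation u_end = u_0 + dt * sum_m b_m * f(u_m, t_m) with
    # quadrature weights b_m. That branch is not implemented here, so the last node
    # must coincide with the right interval boundary and a broadcast copy suffices.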

    def compute_residual(self, stage=None):
        """
        Computation of the residual using the collocation matrix Q

        Args:
            stage (str): The current stage of the step the level belongs to
        """

        L = self.level

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            L.status.residual = 0.0 if L.status.residual is None else L.status.residual
            return None

        # compute the residual for each node

        # build QF(u)
        res = self.integrate(last_only=L.params.residual_type[:4] == 'last')
        res += L.u[0] - L.u[self.rank + 1]
        # add tau if associated
        if L.tau[self.rank] is not None:
            res += L.tau[self.rank]
        # use abs function from data type here
        res_norm = abs(res)

        # find maximal residual over the nodes
        if L.params.residual_type == 'full_abs':
            L.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX)
        elif L.params.residual_type == 'last_abs':
            L.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1)
        elif L.params.residual_type == 'full_rel':
            L.status.residual = self.comm.allreduce(res_norm / abs(L.u[0]), op=MPI.MAX)
        elif L.params.residual_type == 'last_rel':
            L.status.residual = self.comm.bcast(res_norm / abs(L.u[0]), root=self.comm.size - 1)
        else:
            raise NotImplementedError(f'residual type "{L.params.residual_type}" not implemented!')

        # indicate that the residual has seen the new values
        L.status.updated = False

        return None
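
    # In formulas: the node-local residual assembled above is
    #     r_m = u_0 + dt * sum_j q_{mj} f(u_j, t_j) - u_m (+ tau_m),
    # i.e. the defect of the collocation problem at node m, where q_{mj} are the
    # entries of the collocation matrix Q. The reductions then take the maximum
    # (or the last-node value) of |r_m| across the ranks.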

    def predict(self):
        """
        Predictor to fill values at nodes before the first sweep

        Default prediction for the sweepers, only copies the values to all collocation nodes
        and evaluates the RHS of the ODE there
        """

        L = self.level
        P = L.prob

        # evaluate RHS at left point
        L.f[0] = P.eval_f(L.u[0], L.time)

        m = self.rank

        if self.params.initial_guess == 'spread':
            # copy u[0] to this rank's collocation node and evaluate the RHS there
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m])
        elif self.params.initial_guess == 'copy':
            # copy u[0] and its RHS evaluation to this rank's collocation node
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.dtype_f(L.f[0])
        elif self.params.initial_guess == 'zero':
            # zero initial guess for u and RHS
            L.u[m + 1] = P.dtype_u(init=P.init, val=0.0)
            L.f[m + 1] = P.dtype_f(init=P.init, val=0.0)
        else:
            raise ParameterError(f'initial_guess option {self.params.initial_guess} not implemented')

        # indicate that this level is now ready for sweeps
        L.status.unlocked = True
        L.status.updated = True

    def communicate_tau_correction_for_full_interval(self):
        """Broadcast the tau correction for the full interval (the last entry) from the last rank to all ranks."""
        L = self.level
        P = L.prob
        if self.rank < self.comm.size - 1:
            L.tau[-1] = P.u_init
        self.comm.Bcast(L.tau[-1], root=self.comm.size - 1)


class generic_implicit_MPI(SweeperMPI, generic_implicit):
    """
    Generic implicit sweeper parallelized across the nodes.
    Please supply a communicator as `comm` to the parameters!

    Attributes:
        rank (int): MPI rank
    """

    def integrate(self, last_only=False):
        """
        Integrates the right-hand side

        Args:
            last_only (bool): If True, integrate only the last node (for the residual); otherwise all of them

        Returns:
            dtype_u: the integral for this rank's node
        """
        L = self.level
        P = L.prob

        me = P.dtype_u(P.init, val=0.0)
        for m in [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes):
            recvBuf = me if m == self.rank else None
            self.comm.Reduce(
                L.dt * self.coll.Qmat[m + 1, self.rank + 1] * L.f[self.rank + 1], recvBuf, root=m, op=MPI.SUM
            )

        return me
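
    # How the Reduce assembles (QF(u))_m: each rank r holds only f(u_{r+1}) and
    # contributes dt * Qmat[m+1, r+1] * f(u_{r+1}); the MPI.SUM reduction with
    # root=m therefore leaves rank m holding
    #     dt * sum_r q_{m+1, r+1} * f(u_{r+1}),
    # i.e. the m-th row of Q applied to F, without ever gathering all function values.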

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # update the MIN-SR-FLEX preconditioner
        self.updateVariableCoeffs(L.status.sweep)

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        rhs = self.integrate()

        # subtract QdF(u^k); only the diagonal entry of the preconditioner acts on this rank's node
        rhs -= L.dt * self.QI[self.rank + 1, self.rank + 1] * L.f[self.rank + 1]

        # add initial value
        rhs += L.u[0]
        # add tau if associated
        if L.tau[self.rank] is not None:
            rhs += L.tau[self.rank]

        # implicit solve with prefactor stemming from the diagonal of Qd
        L.u[self.rank + 1] = P.solve_system(
            rhs,
            L.dt * self.QI[self.rank + 1, self.rank + 1],
            L.u[self.rank + 1],
            L.time + L.dt * self.coll.nodes[self.rank],
        )
        # update function values
        L.f[self.rank + 1] = P.eval_f(L.u[self.rank + 1], L.time + L.dt * self.coll.nodes[self.rank])

        # indicate presence of new values at this level
        L.status.updated = True

        return None
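
    # In formulas: with a diagonal preconditioner Q_Delta = diag(qd_1, ..., qd_M),
    # each rank solves its node equation independently,
    #     u_m^{k+1} - dt * qd_m * f(u_m^{k+1}) = u_0 + dt * (QF(u^k))_m - dt * qd_m * f(u_m^k) (+ tau_m),
    # which is exactly the rhs assembled above, passed to P.solve_system with
    # prefactor dt * qd_m.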

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the Picard formulation unless do_coll_update==False

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            L.uend = P.dtype_u(L.u[0])
            self.comm.Allreduce(L.dt * self.coll.weights[self.rank] * L.f[self.rank + 1], L.uend, op=MPI.SUM)
            L.uend += L.u[0]

            # add up tau correction of the full interval (last entry)
            if L.tau[self.rank] is not None:
                self.communicate_tau_correction_for_full_interval()
                L.uend += L.tau[-1]
        return None
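

# Usage sketch (an assumption, not part of this file): `my_problem_class` and the
# parameter values below are placeholders; only `sweeper_class`, `sweeper_params`,
# `num_nodes`, and `comm` follow the conventions used above.
#
#     from mpi4py import MPI
#
#     sweeper_params = {
#         'num_nodes': 3,           # one MPI rank per collocation node
#         'comm': MPI.COMM_WORLD,   # optional; defaults to MPI.COMM_WORLD
#     }
#     description = {
#         'problem_class': my_problem_class,
#         'sweeper_class': generic_implicit_MPI,
#         'sweeper_params': sweeper_params,
#         # ... level_params, step_params, etc.
#     }
#
# and run with e.g. `mpirun -np 3 python my_script.py` so that comm.size matches num_nodes.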