Coverage for pySDC/implementations/sweeper_classes/generic_implicit_MPI.py: 89% (104 statements)


from mpi4py import MPI

from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.core.sweeper import Sweeper, ParameterError
import logging


class SweeperMPI(Sweeper):
    """
    MPI-based sweeper where each rank administers one collocation node. Sweepers are adapted to MPI by means of
    multiple inheritance. See for example the `generic_implicit_MPI` sweeper, which has the class definition:

    ```
    class generic_implicit_MPI(SweeperMPI, generic_implicit):
    ```

    This means it inherits from both `SweeperMPI` and `generic_implicit`. The method resolution order is such that
    functions are called first on `SweeperMPI` and then on `generic_implicit`. For instance, in the `__init__`
    function, the `SweeperMPI` class adds a communicator and nothing else, while the `generic_implicit` class adds
    a preconditioner and so on. Note that `self.params` is overwritten in the second call to the `__init__` of the
    core `Sweeper` class, but the parameters that `SweeperMPI` adds to the `params` dictionary are added back in
    `generic_implicit`.
    """
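
    # Because functions are resolved along the method resolution order, a hypothetical interactive check
    # (assuming both classes are importable) shows the order in which `super()` calls are dispatched:
    #
    #     >>> generic_implicit_MPI.__mro__
    #     (generic_implicit_MPI, SweeperMPI, generic_implicit, Sweeper, object)
    #
    # i.e. `super().__init__` inside `SweeperMPI.__init__` continues with `generic_implicit.__init__`.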


    def __init__(self, params, level):
        """
        Initialization routine for the sweeper

        Args:
            params: parameters for the sweeper
            level (pySDC.Level.level): the level that uses this sweeper
        """
        self.logger = logging.getLogger('sweeper')

        if 'comm' not in params.keys():
            params['comm'] = MPI.COMM_WORLD
            self.logger.debug('Using MPI.COMM_WORLD for the communicator because none was supplied in the params.')
        super().__init__(params, level)

        if self.params.comm.size != self.coll.num_nodes:
            raise NotImplementedError(
                f'The communicator in the {type(self).__name__} sweeper needs to have one rank for each node as of now! That means we need {self.coll.num_nodes} processes, but got {self.params.comm.size}.'
            )


    @property
    def comm(self):
        """MPI communicator used by this sweeper"""
        return self.params.comm

    @property
    def rank(self):
        """Rank of this process in the sweeper communicator"""
        return self.comm.rank


    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the Picard formulation unless do_coll_update==False

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # if the last node coincides with the right interval boundary and do_coll_update is False, a copy suffices
        if self.coll.right_is_node and not self.params.do_coll_update:
            root = self.comm.Get_size() - 1
            if self.comm.rank == root:
                L.uend = P.dtype_u(L.u[-1])
            else:
                # all other ranks need a correctly typed buffer to receive the broadcast into
                L.uend = P.dtype_u(L.u[0])
            self.comm.Bcast(L.uend, root=root)
        else:
            raise NotImplementedError('require last node to be identical with right interval boundary')

        return None

    def compute_residual(self, stage=None):
        """
        Computation of the residual using the collocation matrix Q

        Args:
            stage (str): The current stage of the step the level belongs to
        """
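        # in formulas, rank r (administering node r + 1) evaluates
        #     res_r = u_0 + dt * sum_j Q[r + 1, j + 1] * f_{j + 1} - u_{r + 1} (+ tau_r),
        # and the scalar stored in L.status.residual is a reduction of abs(res_r) over all ranks,
        # selected by L.params.residual_type below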


        L = self.level

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            L.status.residual = 0.0 if L.status.residual is None else L.status.residual
            return None

        # compute the residual for the locally owned node

        # build QF(u)
        res = self.integrate(last_only=L.params.residual_type[:4] == 'last')
        res += L.u[0] - L.u[self.rank + 1]
        # add tau if associated
        if L.tau[self.rank] is not None:
            res += L.tau[self.rank]
        # use abs function from data type here
        res_norm = abs(res)

        # find maximal residual over the nodes
        if L.params.residual_type == 'full_abs':
            L.status.residual = self.comm.allreduce(res_norm, op=MPI.MAX)
        elif L.params.residual_type == 'last_abs':
            L.status.residual = self.comm.bcast(res_norm, root=self.comm.size - 1)
        elif L.params.residual_type == 'full_rel':
            L.status.residual = self.comm.allreduce(res_norm / abs(L.u[0]), op=MPI.MAX)
        elif L.params.residual_type == 'last_rel':
            L.status.residual = self.comm.bcast(res_norm / abs(L.u[0]), root=self.comm.size - 1)
        else:
            raise NotImplementedError(f'residual type "{L.params.residual_type}" not implemented!')

        # indicate that the residual has seen the new values
        L.status.updated = False

        return None

    def predict(self):
        """
        Predictor to fill values at nodes before first sweep

        Default prediction for the sweepers, only copies the values to all collocation nodes
        and evaluates the RHS of the ODE there
        """

        L = self.level
        P = L.prob

        # evaluate RHS at left point
        L.f[0] = P.eval_f(L.u[0], L.time)

        m = self.rank

        if self.params.initial_guess == 'spread':
            # copy u[0] to the locally owned collocation node and evaluate the RHS there
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.eval_f(L.u[m + 1], L.time + L.dt * self.coll.nodes[m])
        elif self.params.initial_guess == 'copy':
            # copy u[0] and the RHS evaluation at the left point
            L.u[m + 1] = P.dtype_u(L.u[0])
            L.f[m + 1] = P.dtype_f(L.f[0])
        elif self.params.initial_guess == 'zero':
            # zero solution and zero RHS
            L.u[m + 1] = P.dtype_u(init=P.init, val=0.0)
            L.f[m + 1] = P.dtype_f(init=P.init, val=0.0)
        else:
            raise ParameterError(f'initial_guess option {self.params.initial_guess} not implemented')

        # indicate that this level is now ready for sweeps
        L.status.unlocked = True
        L.status.updated = True


    def communicate_tau_correction_for_full_interval(self):
        """
        Broadcast the tau correction for the full interval (the last entry) from the last rank to all others.
        """
        L = self.level
        P = L.prob
        if self.rank < self.comm.size - 1:
            # all ranks except the last one need a correctly typed buffer to receive the broadcast into
            L.tau[-1] = P.u_init
        self.comm.Bcast(L.tau[-1], root=self.comm.size - 1)


class generic_implicit_MPI(SweeperMPI, generic_implicit):
    """
    Generic implicit sweeper parallelized across the nodes.
    Please supply a communicator as `comm` to the parameters!

    Attributes:
        rank (int): MPI rank
    """
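
    # A hypothetical parameter set for this sweeper; apart from the `comm` entry (handled by `SweeperMPI`),
    # the keys follow the usual serial sweeper parameters and are not defined in this file:
    #
    #     sweeper_params = {
    #         'quad_type': 'RADAU-RIGHT',
    #         'num_nodes': 3,          # must equal comm.size, see SweeperMPI.__init__
    #         'QI': 'MIN',             # a diagonal preconditioner lets all nodes be updated in parallel
    #         'comm': MPI.COMM_WORLD,  # one rank per collocation node
    #     }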


    def integrate(self, last_only=False):
        """
        Integrates the right-hand side

        Args:
            last_only (bool): Integrate only to the last node, as needed for 'last'-type residuals, instead of
                all of them

        Returns:
            dtype_u: the integral associated with this rank's node (only the last rank's value is meaningful
                if last_only is True)
        """
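        # conceptually this assembles one row of the node-parallel matrix-vector product dt * (Q @ F):
        # every rank multiplies its own f with its column entry of Q, and the reduction onto rank m sums
        # these contributions so that rank m ends up holding the integral associated with node m + 1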

        L = self.level
        P = L.prob

        me = P.dtype_u(P.init, val=0.0)
        for m in [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes):
            # rank m receives the reduced sum for node m + 1; all other ranks only send
            recvBuf = me if m == self.rank else None
            self.comm.Reduce(
                L.dt * self.coll.Qmat[m + 1, self.rank + 1] * L.f[self.rank + 1], recvBuf, root=m, op=MPI.SUM
            )

        return me


    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        rhs = self.integrate()

        # subtract QdF(u^k); with the diagonal preconditioner, only the local diagonal entry contributes
        rhs -= L.dt * self.QI[self.rank + 1, self.rank + 1] * L.f[self.rank + 1]

        # add initial value
        rhs += L.u[0]
        # add tau if associated
        if L.tau[self.rank] is not None:
            rhs += L.tau[self.rank]

        # implicit solve with prefactor stemming from the diagonal of Qd
        L.u[self.rank + 1] = P.solve_system(
            rhs,
            L.dt * self.QI[self.rank + 1, self.rank + 1],
            L.u[self.rank + 1],
            L.time + L.dt * self.coll.nodes[self.rank],
        )
        # update function values
        L.f[self.rank + 1] = P.eval_f(L.u[self.rank + 1], L.time + L.dt * self.coll.nodes[self.rank])

        # indicate presence of new values at this level
        L.status.updated = True

        return None


    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the Picard formulation unless do_coll_update==False

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # if the last node coincides with the right interval boundary and do_coll_update is False, a copy suffices
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            # allocate a buffer of the correct type; the Allreduce overwrites its contents
            L.uend = P.dtype_u(L.u[0])
            self.comm.Allreduce(L.dt * self.coll.weights[self.rank] * L.f[self.rank + 1], L.uend, op=MPI.SUM)
            L.uend += L.u[0]

            # add up tau correction of the full interval (last entry)
            if L.tau[self.rank] is not None:
                self.communicate_tau_correction_for_full_interval()
                L.uend += L.tau[-1]
        return None
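

# A minimal standalone sketch of the node-parallel integration pattern used in `integrate` above, with plain
# numpy arrays standing in for dtype_u/dtype_f and a made-up matrix Q; everything below (names, sizes, seed)
# is illustrative only. Run with e.g. `mpirun -n 3 python generic_implicit_MPI.py`.
if __name__ == '__main__':
    import numpy as np

    comm = MPI.COMM_WORLD
    M = comm.size  # one rank per collocation node

    rng = np.random.default_rng(1234)  # identical seed on every rank -> identical Q and f everywhere
    Q = rng.random((M, M))
    f = rng.random((M, 4))  # row j plays the role of the rhs evaluation owned by rank j

    # every rank contributes Q[m, rank] * f[rank] to node m; reducing onto rank m assembles row m of Q @ f
    me = np.zeros(4)
    for m in range(M):
        recvBuf = me if m == comm.rank else None
        comm.Reduce(Q[m, comm.rank] * f[comm.rank], recvBuf, root=m, op=MPI.SUM)

    # each rank now holds exactly its own row of the product, just like each sweeper rank holds its node's integral
    assert np.allclose(me, (Q @ f)[comm.rank]), 'distributed and serial matrix-vector products disagree'
    print(f'rank {comm.rank}: node-parallel integral matches serial Q @ f')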