Coverage for pySDC/implementations/sweeper_classes/imex_1st_order_MPI.py: 95%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-20 17:10 +0000

1from mpi4py import MPI 

2from pySDC.implementations.sweeper_classes.generic_implicit_MPI import SweeperMPI 

3from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order 

4 

5 

class imex_1st_order_MPI(SweeperMPI, imex_1st_order):
    """
    MPI-parallel IMEX sweeper in which every rank of ``self.comm`` handles the collocation node
    with index ``self.rank + 1``. Collocation integrals are assembled with MPI reductions.

    Only the Picard (``'PIC'``) explicit preconditioner is supported. As a consequence the sweep
    in ``update_nodes`` subtracts only the implicit diagonal term and has no explicit correction.
    """

    def __init__(self, params):
        """
        Initialization routine

        Args:
            params (dict): sweeper parameters, forwarded to the parent classes

        Raises:
            AssertionError: if the explicit preconditioner ``QE`` is not ``'PIC'``
        """
        super().__init__(params)
        assert (
            self.params.QE == 'PIC'
        ), f"Only Picard is implemented for explicit preconditioner so far in {type(self).__name__}! You chose \"{self.params.QE}\""

    def integrate(self, last_only=False):
        """
        Integrates the right-hand side (here impl + expl)

        For every target node ``m``, each rank sends ``dt * Qmat[m + 1, rank + 1] * (impl + expl)``
        evaluated at its own node. These contributions are summed onto rank ``m`` with
        ``comm.Reduce``, so each rank receives only the integral for the node it owns.

        Args:
            last_only (bool): Integrate only the last node for the residual or all of them

        Returns:
            dtype_u: the integral at the node owned by this rank. Ranks that were not the root of
            any reduction (all but the last one when ``last_only`` is set) get zero.
        """
        L = self.level
        P = L.prob

        # right-hand side at this rank's node; it does not depend on the target node m
        f_local = L.f[self.rank + 1].impl + L.f[self.rank + 1].expl

        me = P.dtype_u(P.init, val=0.0)
        targets = [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes)
        for m in targets:
            # only the root of this reduction needs a receive buffer
            recvBuf = me if m == self.rank else None
            self.comm.Reduce(
                L.dt * self.coll.Qmat[m + 1, self.rank + 1] * f_local,
                recvBuf,
                root=m,
                op=MPI.SUM,
            )

        return me

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Each rank updates only the node it owns (``self.rank + 1``).

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        node = self.rank + 1  # collocation node owned by this rank
        t_node = L.time + L.dt * self.coll.nodes[self.rank]

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        rhs = self.integrate()

        # subtract QdF(u^k); only the implicit part, since QE is Picard (zero)
        rhs -= L.dt * (self.QI[node, node] * L.f[node].impl)

        # add initial conditions
        rhs += L.u[0]
        # add tau if associated
        if L.tau[self.rank] is not None:
            rhs += L.tau[self.rank]

        # implicit solve with prefactor stemming from the diagonal of Qd
        L.u[node] = P.solve_system(
            rhs,
            L.dt * self.QI[node, node],
            L.u[node],
            t_node,
        )
        # update function values
        L.f[node] = P.eval_f(L.u[node], t_node)

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        If the last collocation node coincides with the right interval boundary and no collocation
        update is requested, the parent implementation is used. Otherwise the full quadrature
        ``u0 + dt * sum_m w_m (f_impl + f_expl)(u_m)`` is assembled with an Allreduce, so all
        ranks hold the result.

        Returns:
            None
        """
        L = self.level
        P = L.prob
        L.uend = P.dtype_u(P.init, val=0.0)

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            # Allreduce overwrites the receive buffer, so the zero buffer allocated above is used
            # directly instead of a throw-away copy of u0
            self.comm.Allreduce(
                L.dt * self.coll.weights[self.rank] * (L.f[self.rank + 1].impl + L.f[self.rank + 1].expl),
                L.uend,
                op=MPI.SUM,
            )
            L.uend += L.u[0]

            # add up tau correction of the full interval (last entry)
            if L.tau[self.rank] is not None:
                self.communicate_tau_correction_for_full_interval()
                L.uend += L.tau[-1]
        return None