Coverage for pySDC/implementations/sweeper_classes/imex_1st_order_MPI.py: 95%
41 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 17:10 +0000
1from mpi4py import MPI
2from pySDC.implementations.sweeper_classes.generic_implicit_MPI import SweeperMPI
3from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
class imex_1st_order_MPI(SweeperMPI, imex_1st_order):
    """
    IMEX first-order sweeper that is parallel across the collocation nodes via MPI.

    Each rank owns a single collocation node, namely ``L.u[self.rank + 1]`` / ``L.f[self.rank + 1]``.
    Coupling between nodes (quadrature and end point) is done with MPI collectives on ``self.comm``.
    Only the Picard explicit preconditioner (``QE == 'PIC'``) is supported, which is why the sweep
    contains no explicit-preconditioner term.

    NOTE(review): reductions use ``root=m`` for every node index ``m``, so the communicator presumably
    has one rank per collocation node -- confirm against the controller setup.
    """

    def __init__(self, params):
        """
        Initialization routine

        Args:
            params: sweeper parameters, forwarded unchanged to the parent classes

        Raises:
            AssertionError: if the explicit preconditioner ``QE`` is anything other than ``'PIC'``
        """
        super().__init__(params)
        assert (
            self.params.QE == 'PIC'
        ), f"Only Picard is implemented for explicit preconditioner so far in {type(self).__name__}! You chose \"{self.params.QE}\""

    def integrate(self, last_only=False):
        """
        Integrates the right-hand side (here impl + expl)

        Every rank contributes ``dt * Qmat[m + 1, rank + 1] * (impl + expl)`` of its own node, and the
        contributions are summed onto rank ``m`` with ``MPI.Reduce``. Rank ``m`` therefore receives the
        ``m``-th entry of ``Q F``.

        Args:
            last_only (bool): Integrate only the last node for the residual or all of them

        Returns:
            dtype_u: the integral at this rank's own node. Ranks that are not the root of any
            reduction (e.g. every rank except the last one when ``last_only=True``) get zeros.
        """
        L = self.level
        P = L.prob

        # right-hand side at this rank's node; it does not depend on m, so evaluate it only once
        f_own = L.f[self.rank + 1].impl + L.f[self.rank + 1].expl

        me = P.dtype_u(P.init, val=0.0)
        nodes = [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes)
        for m in nodes:
            # only the root of this reduction needs a receive buffer
            recvBuf = me if m == self.rank else None
            self.comm.Reduce(
                L.dt * self.coll.Qmat[m + 1, self.rank + 1] * f_own,
                recvBuf,
                root=m,
                op=MPI.SUM,
            )

        return me

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Each rank updates only its own node ``m = self.rank + 1``. It assembles

            rhs = (Q F(u^k))_m - dt * QI[m, m] * f_impl(u_m^k) + u_0 + tau[self.rank]

        (the tau term is added only if it is set), solves the implicit system with prefactor
        ``dt * QI[m, m]``, and re-evaluates ``f`` at the new value. No explicit-preconditioner term is
        subtracted, because ``__init__`` only admits ``QE == 'PIC'``.

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        m = self.rank + 1  # index of the node owned by this rank in L.u / L.f
        t_m = L.time + L.dt * self.coll.nodes[self.rank]  # time at that node

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        rhs = self.integrate()

        # subtract QdF(u^k), implicit part only (explicit preconditioner is Picard)
        rhs -= L.dt * (self.QI[m, m] * L.f[m].impl)

        # add initial conditions
        rhs += L.u[0]
        # add tau if associated
        if L.tau[self.rank] is not None:
            rhs += L.tau[self.rank]

        # implicit solve with prefactor stemming from the diagonal of Qd,
        # using the current node value as initial guess
        L.u[m] = P.solve_system(rhs, L.dt * self.QI[m, m], L.u[m], t_m)
        # update function values
        L.f[m] = P.eval_f(L.u[m], t_m)

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        If the last collocation node coincides with the right boundary and ``do_coll_update`` is off,
        the parent class's end-point routine is used. Otherwise the full collocation update

            uend = u_0 + sum_j dt * weights[j] * (impl + expl)_j

        is assembled with ``MPI.Allreduce``. The full-interval tau correction ``L.tau[-1]`` is added
        if tau is set.

        Returns:
            None
        """
        L = self.level
        P = L.prob
        L.uend = P.dtype_u(P.init, val=0.0)

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            # Allreduce overwrites its receive buffer, so the zero-initialized L.uend above serves
            # directly as the buffer; u0 is added afterwards (exactly once)
            self.comm.Allreduce(
                L.dt * self.coll.weights[self.rank] * (L.f[self.rank + 1].impl + L.f[self.rank + 1].expl),
                L.uend,
                op=MPI.SUM,
            )
            L.uend += L.u[0]

            # add up tau correction of the full interval (last entry)
            if L.tau[self.rank] is not None:
                self.communicate_tau_correction_for_full_interval()
                L.uend += L.tau[-1]
        return None