Source code for implementations.sweeper_classes.imex_1st_order_MPI

from mpi4py import MPI
from pySDC.implementations.sweeper_classes.generic_implicit_MPI import SweeperMPI
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order


class imex_1st_order_MPI(SweeperMPI, imex_1st_order):
    """
    MPI variant of the first-order IMEX sweeper.

    Every rank works on exactly one collocation node, namely the one with index
    ``self.rank`` (stored at position ``self.rank + 1`` in ``L.u`` and ``L.f``).
    Contributions from the other nodes are combined via reductions on ``self.comm``.

    Only the Picard choice (``QE == 'PIC'``) is supported for the explicit part.
    """

    def __init__(self, params):
        """
        Initialize the sweeper via the parent classes and validate the explicit preconditioner.

        Args:
            params: sweeper parameters, passed on unchanged to the parent initializer

        Raises:
            AssertionError: if ``self.params.QE`` is anything other than ``'PIC'``
        """
        super().__init__(params)

        # Explicit raise instead of an ``assert`` statement so that the check is not
        # removed when Python runs with ``-O``; the exception type is unchanged.
        if self.params.QE != 'PIC':
            raise AssertionError(
                f"Only Picard is implemented for explicit precondioner so far in {type(self).__name__}! You chose \"{self.params.QE}\""
            )

    def integrate(self, last_only=False):
        """
        Integrates the right-hand side (here impl + expl)

        One ``Reduce`` with ``op=MPI.SUM`` is issued per target node ``m``, with rank ``m``
        as root. Each rank contributes ``dt * Qmat[m + 1, rank + 1] * f`` of its own node.

        Args:
            last_only (bool): Integrate only the last node for the residual or all of them

        Returns:
            dtype_u: the reduced integral for the node of this rank. If this rank is not
            the root of any reduction (e.g. ``last_only=True`` on a rank other than the
            last node), the returned value keeps its zero initialization.
        """
        L = self.level
        P = L.prob

        me = P.dtype_u(P.init, val=0.0)

        # full right-hand side at this rank's node; does not depend on the target node m
        f_total = L.f[self.rank + 1].impl + L.f[self.rank + 1].expl

        for m in [self.coll.num_nodes - 1] if last_only else range(self.coll.num_nodes):
            # only the root of the reduction supplies a receive buffer
            recvBuf = me if m == self.rank else None
            self.comm.Reduce(
                L.dt * self.coll.Qmat[m + 1, self.rank + 1] * f_total,
                recvBuf,
                root=m,
                op=MPI.SUM,
            )

        return me

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Each rank updates only its own node (index ``self.rank + 1`` in ``L.u`` / ``L.f``).

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        rhs = self.integrate()

        # subtract QdF(u^k); only the implicit part with the diagonal entry of QI is removed here
        rhs -= L.dt * (self.QI[self.rank + 1, self.rank + 1] * L.f[self.rank + 1].impl)

        # add initial conditions
        rhs += L.u[0]

        # add tau if associated
        if L.tau[self.rank] is not None:
            rhs += L.tau[self.rank]

        # implicit solve with prefactor stemming from the diagonal of Qd
        L.u[self.rank + 1] = P.solve_system(
            rhs,
            L.dt * self.QI[self.rank + 1, self.rank + 1],
            L.u[self.rank + 1],
            L.time + L.dt * self.coll.nodes[self.rank],
        )

        # update function values
        L.f[self.rank + 1] = P.eval_f(L.u[self.rank + 1], L.time + L.dt * self.coll.nodes[self.rank])

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        If the last collocation node coincides with the right end point and
        ``do_coll_update`` is not set, the parent implementation is used. Otherwise the
        weighted right-hand sides of all ranks are summed with an ``Allreduce`` and the
        initial value (plus the tau correction, if present) is added.

        Returns:
            None
        """
        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            super().compute_end_point()
        else:
            # L.uend serves as the receive buffer of the Allreduce below
            L.uend = P.dtype_u(L.u[0])
            self.comm.Allreduce(
                L.dt * self.coll.weights[self.rank] * (L.f[self.rank + 1].impl + L.f[self.rank + 1].expl),
                L.uend,
                op=MPI.SUM,
            )
            L.uend += L.u[0]

            # add up tau correction of the full interval (last entry)
            if L.tau[self.rank] is not None:
                self.communicate_tau_correction_for_full_interval()
                L.uend += L.tau[-1]

        return None