Coverage for pySDC/helpers/firedrake_ensemble_communicator.py: 0%
43 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-20 10:09 +0000
1from mpi4py import MPI
2import firedrake as fd
3import numpy as np
class FiredrakeEnsembleCommunicator:
    """
    Ensemble communicator for performing multiple similar distributed simulations with Firedrake, see
    https://www.firedrakeproject.org/firedrake/parallelism.html
    This is intended to do space-time parallelism in pySDC.
    This class wraps the time communicator. All requests that are not overloaded are passed to the time
    communicator. For instance, `ensemble.rank` will return the rank in the time communicator.
    Some operations are overloaded to use the interface of the MPI communicator but handle communication
    with the ensemble communicator instead: numpy arrays (and, for point-to-point messages, lists) go
    directly to the raw time communicator, while any other buffer is handed to the corresponding
    `firedrake.Ensemble` method (presumably firedrake Functions -- confirm with callers).
    """

    def __init__(self, comm, space_size):
        """
        Args:
            comm (MPI.Intracomm): MPI communicator, which will be split into time and space communicators
            space_size (int): Size of the spatial communicators

        Attributes:
            ensemble (firedrake.Ensemble): Ensemble communicator
            comm_world (MPI.Intracomm): The communicator passed in, kept so that `Split` can build new
                ensembles from parts of it
        """
        self.ensemble = fd.Ensemble(comm, space_size)
        self.comm_world = comm

    @property
    def comm_wold(self):
        """Misspelled alias of `comm_world`, kept for backward compatibility."""
        return self.comm_world

    @comm_wold.setter
    def comm_wold(self, value):
        self.comm_world = value

    def Split(self, *args, **kwargs):
        """
        Split `comm_world` (arguments are forwarded to `MPI.Intracomm.Split`) and wrap the resulting
        communicator in a new ensemble communicator with the same spatial communicator size as this one.
        """
        return FiredrakeEnsembleCommunicator(self.comm_world.Split(*args, **kwargs), space_size=self.space_comm.size)

    @property
    def space_comm(self):
        """Spatial communicator of this rank (`ensemble.comm`)."""
        return self.ensemble.comm

    @property
    def time_comm(self):
        """Time communicator connecting ranks with the same spatial rank (`ensemble.ensemble_comm`)."""
        return self.ensemble.ensemble_comm

    def __getattr__(self, name):
        """
        Forward any attribute not defined on this class to the time communicator.

        `__getattr__` only runs when normal lookup fails. If `ensemble` has not been set on the instance
        (e.g. an instance created through `__new__` by `copy` or `pickle`, which then probe attributes
        such as `__setstate__`), going through `self.time_comm` would look up `self.ensemble`, land here
        again and recurse until RecursionError. `self.__dict__` is read directly because it never
        triggers `__getattr__`.
        """
        if 'ensemble' not in self.__dict__:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        return getattr(self.time_comm, name)

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        """
        Reduce `sendbuf` into `recvbuf` on rank `root` of the time communicator.

        Numpy arrays are reduced with the raw time communicator and any `op`. Any other buffer goes to
        `firedrake.Ensemble.reduce`, for which only `MPI.SUM` is accepted here.
        """
        if isinstance(sendbuf, np.ndarray):
            self.time_comm.Reduce(sendbuf, recvbuf, op, root)
        else:
            assert op == MPI.SUM, 'Only MPI.SUM is supported when reducing non-numpy buffers'
            self.ensemble.reduce(sendbuf, recvbuf, root=root)

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        """
        Reduce `sendbuf` into `recvbuf` on all ranks of the time communicator.

        Numpy arrays are reduced with the raw time communicator and any `op`. Any other buffer goes to
        `firedrake.Ensemble.allreduce`, for which only `MPI.SUM` is accepted here.
        """
        if isinstance(sendbuf, np.ndarray):
            self.time_comm.Allreduce(sendbuf, recvbuf, op)
        else:
            assert op == MPI.SUM, 'Only MPI.SUM is supported when reducing non-numpy buffers'
            self.ensemble.allreduce(sendbuf, recvbuf)

    def Bcast(self, buf, root=0):
        """
        Broadcast `buf` in place from rank `root` of the time communicator.

        Numpy arrays use the raw time communicator; any other buffer goes to `firedrake.Ensemble.bcast`.
        """
        if isinstance(buf, np.ndarray):
            self.time_comm.Bcast(buf, root)
        else:
            self.ensemble.bcast(buf, root=root)

    def Irecv(self, buf, source, tag=MPI.ANY_TAG):
        """
        Start a non-blocking receive into `buf` from rank `source` of the time communicator.

        Numpy arrays and lists (e.g. an mpi4py `[data, datatype]` buffer spec) use the raw time
        communicator; any other buffer goes to `firedrake.Ensemble.irecv`.

        Returns:
            MPI.Request: request to wait on. `Ensemble.irecv` returns a list of requests, and only its
            first entry is returned. NOTE(review): if that list can hold several requests (e.g. for
            functions with multiple components), the others are never waited on -- confirm.
        """
        if isinstance(buf, (np.ndarray, list)):
            return self.time_comm.Irecv(buf=buf, source=source, tag=tag)
        return self.ensemble.irecv(buf, source, tag=tag)[0]

    def Isend(self, buf, dest, tag=0):
        """
        Start a non-blocking send of `buf` to rank `dest` of the time communicator.

        Numpy arrays and lists use the raw time communicator; any other buffer goes to
        `firedrake.Ensemble.isend`. The default tag is 0 (as in mpi4py) because `MPI.ANY_TAG` is only
        valid for receives and makes a send erroneous.

        Returns:
            MPI.Request: request to wait on. `Ensemble.isend` returns a list of requests, and only its
            first entry is returned. NOTE(review): if that list can hold several requests, the others
            are never waited on -- confirm.
        """
        if isinstance(buf, (np.ndarray, list)):
            return self.time_comm.Isend(buf=buf, dest=dest, tag=tag)
        return self.ensemble.isend(buf, dest, tag=tag)[0]

    def Free(self):
        """
        No-op.

        Defining `Free` on this class prevents `__getattr__` from forwarding it to the time
        communicator, so calling `Free` on the wrapper does not free the communicators held by the
        firedrake ensemble. (`del self` inside a method would only unbind the local name, so it was
        equally a no-op.)
        """
def get_ensemble(comm, space_size):
    """
    Create a plain `firedrake.Ensemble`, without the MPI-style wrapper provided by
    `FiredrakeEnsembleCommunicator`.

    Args:
        comm (MPI.Intracomm): communicator to split into time and space communicators
        space_size (int): size of each spatial communicator

    Returns:
        firedrake.Ensemble: the newly created ensemble
    """
    ensemble = fd.Ensemble(comm, space_size)
    return ensemble