Coverage for pySDC/helpers/firedrake_ensemble_communicator.py: 0%

43 statements  


from mpi4py import MPI
import firedrake as fd
import numpy as np


class FiredrakeEnsembleCommunicator:
    """
    Ensemble communicator for performing multiple similar distributed simulations with Firedrake, see https://www.firedrakeproject.org/firedrake/parallelism.html
    This is intended to do space-time parallelism in pySDC.
    This class wraps the time communicator. All requests that are not overloaded are passed to the time communicator. For instance, `ensemble.rank` will return the rank in the time communicator.
    Some operations are overloaded to use the interface of the MPI communicator but handle communication with the ensemble communicator instead. A usage sketch follows the class definition below.
    """

    def __init__(self, comm, space_size):
        """
        Args:
            comm (MPI.Intracomm): MPI communicator, which will be split into time and space communicators
            space_size (int): Size of the spatial communicators

        Attributes:
            ensemble (firedrake.Ensemble): Ensemble communicator
        """
        self.ensemble = fd.Ensemble(comm, space_size)
        self.comm_world = comm

    def Split(self, *args, **kwargs):
        return FiredrakeEnsembleCommunicator(self.comm_world.Split(*args, **kwargs), space_size=self.space_comm.size)

    @property
    def space_comm(self):
        return self.ensemble.comm

    @property
    def time_comm(self):
        return self.ensemble.ensemble_comm

    def __getattr__(self, name):
        # forward anything that is not overloaded to the time communicator
        return getattr(self.time_comm, name)

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        # raw numpy buffers go through the plain time communicator, everything
        # else (e.g. Firedrake Functions) through the ensemble, which only
        # supports the default summation here
        if type(sendbuf) in [np.ndarray]:
            self.ensemble.ensemble_comm.Reduce(sendbuf, recvbuf, op, root)
        else:
            assert op == MPI.SUM
            self.ensemble.reduce(sendbuf, recvbuf, root=root)

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        if type(sendbuf) in [np.ndarray]:
            self.ensemble.ensemble_comm.Allreduce(sendbuf, recvbuf, op)
        else:
            assert op == MPI.SUM
            self.ensemble.allreduce(sendbuf, recvbuf)

    def Bcast(self, buf, root=0):
        if type(buf) in [np.ndarray]:
            self.ensemble.ensemble_comm.Bcast(buf, root)
        else:
            self.ensemble.bcast(buf, root=root)

    def Irecv(self, buf, source, tag=MPI.ANY_TAG):
        if type(buf) in [np.ndarray, list]:
            return self.ensemble.ensemble_comm.Irecv(buf=buf, source=source, tag=tag)
        # firedrake's irecv returns one request per function component; only
        # the first is returned here
        return self.ensemble.irecv(buf, source, tag=tag)[0]

    def Isend(self, buf, dest, tag=MPI.ANY_TAG):
        if type(buf) in [np.ndarray, list]:
            return self.ensemble.ensemble_comm.Isend(buf=buf, dest=dest, tag=tag)
        return self.ensemble.isend(buf, dest, tag=tag)[0]

    def Free(self):
        # mirrors the MPI interface; this only drops the local reference and
        # does not free the underlying communicators
        del self


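# --- Usage sketch (editorial addition, not part of the original module) ---
# A minimal, hedged example of how the wrapper is meant to be used: the world
# communicator is split into spatial communicators of size `space_size`, a
# Firedrake Function is broadcast through the ensemble, and a plain numpy
# buffer goes through the raw time communicator. The mesh and function space
# are illustrative assumptions only; run with something like
# `mpiexec -n 4 python demo.py` so the size is divisible by `space_size`.
def _ensemble_communicator_demo(space_size=2):
    ensemble_comm = FiredrakeEnsembleCommunicator(MPI.COMM_WORLD, space_size)

    # build the mesh on the spatial communicator so that every time rank
    # holds the same spatial partitioning
    mesh = fd.UnitSquareMesh(4, 4, comm=ensemble_comm.space_comm)
    V = fd.FunctionSpace(mesh, 'CG', 1)

    # Firedrake Functions are dispatched to fd.Ensemble ...
    u = fd.Function(V)
    u.interpolate(fd.Constant(float(ensemble_comm.rank)))
    ensemble_comm.Bcast(u, root=0)

    # ... while numpy arrays use the time communicator directly
    t = np.array([float(ensemble_comm.rank)])
    ensemble_comm.Bcast(t, root=0)
    return u, t

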

def get_ensemble(comm, space_size):
    return fd.Ensemble(comm, space_size)
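

# A second hedged sketch: nearest-neighbour exchange along the time direction
# with the capitalized MPI-style interface. `buf_send`/`buf_recv` may be
# Firedrake Functions or numpy arrays; the type dispatch in Isend/Irecv picks
# the matching communicator. The forward-only pattern is an assumption for
# illustration.
def _time_neighbor_exchange_demo(ensemble_comm, buf_send, buf_recv):
    me = ensemble_comm.rank
    requests = []
    if me + 1 < ensemble_comm.size:
        # send to the next process in time
        requests.append(ensemble_comm.Isend(buf_send, dest=me + 1, tag=0))
    if me > 0:
        # receive from the previous process in time
        requests.append(ensemble_comm.Irecv(buf_recv, source=me - 1, tag=0))
    MPI.Request.Waitall(requests)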