Coverage for pySDC/helpers/NCCL_communicator.py: 22%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1from mpi4py import MPI 

2from cupy.cuda import nccl 

3import cupy as cp 

4import numpy as np 

5 

6 

class NCCLComm(object):
    """
    Wraps an MPI communicator and performs some calls to NCCL functions instead.

    The buffer collectives ``Reduce``, ``Allreduce`` and ``Bcast`` are routed through NCCL on the current CuPy stream
    when the buffer exposes a device pointer (``buf.data.ptr``, as CuPy arrays do). All other calls, and the same
    collectives on host buffers, are forwarded to the wrapped MPI communicator, synchronizing host and device first
    where data may be involved.
    """

    # Attributes that only query or derive the communicator itself, so no device synchronization is needed before
    # forwarding them to MPI.
    _NO_SYNC_ATTRS = frozenset(
        ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']
    )

    def __init__(self, comm):
        """
        Args:
            comm (mpi4py.Intracomm): MPI communicator
        """
        self.commMPI = comm

        # Only rank 0's id is used, so create it there alone and broadcast it. NCCL expects the unique id to be
        # generated once and distributed to all ranks before communicator initialization.
        uid = nccl.get_unique_id() if comm.rank == 0 else None
        uid = comm.bcast(uid, root=0)
        self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank)

    def __getattr__(self, name):
        """
        Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator.
        When performing any operations that depend on data, we have to synchronize host and device beforehand.

        Args:
            name (str): Name of the requested attribute

        Returns:
            The attribute of the same name on the wrapped MPI communicator.

        Raises:
            AttributeError: If the attribute does not exist, or if ``commMPI`` has not been set yet.
        """
        if name == 'commMPI':
            # Only reached before __init__ has set the attribute (e.g. while copying or unpickling an instance).
            # Without this guard, ``self.commMPI`` below would re-enter __getattr__ and recurse forever.
            raise AttributeError(name)

        if name not in self._NO_SYNC_ATTRS:
            cp.cuda.get_current_stream().synchronize()

        return getattr(self.commMPI, name)

    @staticmethod
    def get_dtype(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.

        Args:
            data: Array whose ``dtype`` determines the NCCL datatype

        Returns:
            The matching ``nccl.NCCL_*`` datatype constant.

        Raises:
            NotImplementedError: If there is no NCCL datatype for ``data.dtype``.
        """
        dtype = data.dtype
        if dtype in [np.dtype('float32'), np.dtype('complex64')]:
            return nccl.NCCL_FLOAT32
        elif dtype in [np.dtype('float64'), np.dtype('complex128')]:
            return nccl.NCCL_FLOAT64
        elif dtype in [np.dtype('int32')]:
            return nccl.NCCL_INT32
        elif dtype in [np.dtype('int64')]:
            return nccl.NCCL_INT64
        else:
            raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {data.dtype}!')

    @staticmethod
    def get_count(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.

        Args:
            data: Array to be communicated

        Returns:
            int: Number of NCCL elements, i.e. twice ``data.size`` for complex data.
        """
        if cp.iscomplexobj(data):
            return data.size * 2
        else:
            return data.size

    def get_op(self, MPI_op):
        """
        Translate an MPI reduction operation into the corresponding NCCL operation.

        Args:
            MPI_op (mpi4py.MPI.Op): One of ``MPI.SUM``, ``MPI.PROD``, ``MPI.MAX`` or ``MPI.MIN``

        Returns:
            The matching ``nccl.NCCL_*`` operation constant.

        Raises:
            NotImplementedError: For any other MPI operation.
        """
        if MPI_op == MPI.SUM:
            return nccl.NCCL_SUM
        elif MPI_op == MPI.PROD:
            return nccl.NCCL_PROD
        elif MPI_op == MPI.MAX:
            return nccl.NCCL_MAX
        elif MPI_op == MPI.MIN:
            return nccl.NCCL_MIN
        else:
            raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

    @staticmethod
    def _is_device_array(buf):
        """Return True if ``buf`` exposes a device pointer via ``buf.data.ptr`` (e.g. a CuPy array)."""
        # getattr with a default keeps this safe for None, MPI.IN_PLACE and buffer-spec lists,
        # which have no ``data`` attribute.
        return hasattr(getattr(buf, 'data', None), 'ptr')

    def _sync_if_device(self, obj):
        """Synchronize the device before handing ``obj`` to MPI if it lives in device memory."""
        if self._is_device_array(obj):
            cp.cuda.Device().synchronize()

    @staticmethod
    def _check_complex_op(data, nccl_op):
        """
        Complex data is communicated as pairs of real numbers. Element-wise that is only equivalent to the complex
        reduction for summation, so refuse other operations instead of silently returning wrong results.
        """
        if cp.iscomplexobj(data) and nccl_op != nccl.NCCL_SUM:
            raise NotImplementedError('NCCL reductions of complex data are only supported for MPI.SUM!')

    def reduce(self, sendobj, op=MPI.SUM, root=0):
        """
        Object-based reduce, forwarded to MPI after synchronizing the device if ``sendobj`` is a device array.

        Args:
            sendobj: Object to reduce
            op (mpi4py.MPI.Op): Reduction operation
            root (int): Rank receiving the result

        Returns:
            The reduced object on ``root``, as returned by the MPI communicator.
        """
        self._sync_if_device(sendobj)
        return self.commMPI.reduce(sendobj, op=op, root=root)

    def allreduce(self, sendobj, op=MPI.SUM):
        """
        Object-based allreduce, forwarded to MPI after synchronizing the device if ``sendobj`` is a device array.

        Args:
            sendobj: Object to reduce
            op (mpi4py.MPI.Op): Reduction operation

        Returns:
            The reduced object, as returned by the MPI communicator.
        """
        self._sync_if_device(sendobj)
        return self.commMPI.allreduce(sendobj, op=op)

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        """
        Buffer reduce onto ``root``. Uses NCCL for device arrays and forwards host buffers to MPI.

        Args:
            sendbuf: Buffer to reduce, or ``MPI.IN_PLACE`` on the root to reduce into ``recvbuf`` in place
            recvbuf: Buffer for the result on ``root``; may be None on other ranks
            op (mpi4py.MPI.Op): Reduction operation
            root (int): Rank receiving the result

        Raises:
            NotImplementedError: For unsupported dtypes or operations, including non-SUM operations on complex data.
        """
        data = recvbuf if sendbuf is MPI.IN_PLACE else sendbuf
        if not self._is_device_array(data):
            return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)

        dtype = self.get_dtype(data)
        count = self.get_count(data)
        nccl_op = self.get_op(op)
        self._check_complex_op(data, nccl_op)
        # Non-root ranks may pass recvbuf=None; NCCL still needs a valid pointer, so give it a small dummy buffer.
        recvbuf = cp.empty(1) if recvbuf is None else recvbuf
        stream = cp.cuda.get_current_stream()

        # For MPI.IN_PLACE, data is recvbuf, so send and receive pointers coincide (NCCL in-place operation).
        self.commNCCL.reduce(
            sendbuf=data.data.ptr,
            recvbuf=recvbuf.data.ptr,
            count=count,
            datatype=dtype,
            op=nccl_op,
            root=root,
            stream=stream.ptr,
        )

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        """
        Buffer allreduce. Uses NCCL for device arrays and forwards host buffers to MPI.

        Args:
            sendbuf: Buffer to reduce, or ``MPI.IN_PLACE`` to reduce into ``recvbuf`` in place
            recvbuf: Buffer for the result
            op (mpi4py.MPI.Op): Reduction operation

        Raises:
            NotImplementedError: For unsupported dtypes or operations, including non-SUM operations on complex data.
        """
        data = recvbuf if sendbuf is MPI.IN_PLACE else sendbuf
        if not self._is_device_array(data):
            return self.commMPI.Allreduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op)

        dtype = self.get_dtype(data)
        count = self.get_count(data)
        nccl_op = self.get_op(op)
        self._check_complex_op(data, nccl_op)
        stream = cp.cuda.get_current_stream()

        # For MPI.IN_PLACE, data is recvbuf, so send and receive pointers coincide (NCCL in-place operation).
        self.commNCCL.allReduce(
            sendbuf=data.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=nccl_op, stream=stream.ptr
        )

    def Bcast(self, buf, root=0):
        """
        Broadcast ``buf`` from ``root``. Uses NCCL for device arrays and forwards host buffers to MPI.

        Args:
            buf: Buffer to send on ``root`` and to receive into on the other ranks
            root (int): Rank sending the data

        Raises:
            NotImplementedError: If there is no NCCL datatype for ``buf.dtype``.
        """
        if not self._is_device_array(buf):
            return self.commMPI.Bcast(buf=buf, root=root)

        dtype = self.get_dtype(buf)
        count = self.get_count(buf)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)

    def Barrier(self):
        """Synchronize the current CuPy stream, then perform an MPI barrier."""
        cp.cuda.get_current_stream().synchronize()
        self.commMPI.Barrier()