Coverage for pySDC/helpers/NCCL_communicator.py: 0%

62 statements  


from mpi4py import MPI
from cupy.cuda import nccl
import cupy as cp
import numpy as np


class NCCLComm(object):
    """
    Wraps an MPI communicator and performs some calls to NCCL functions instead.
    """

    def __init__(self, comm):
        """
        Args:
            comm (mpi4py.Intracomm): MPI communicator
        """
        self.commMPI = comm

        # All ranks need the same NCCL unique id: rank 0 generates it and broadcasts it via MPI.
        uid = comm.bcast(nccl.get_unique_id(), root=0)
        self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank)

    def __getattr__(self, name):
        """
        Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator.
        When performing any operations that depend on data, we have to synchronize host and device beforehand.

        Args:
            name (str): Name of the requested attribute
        """
        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
            cp.cuda.get_current_stream().synchronize()

        return getattr(self.commMPI, name)

    @staticmethod
    def get_dtype(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        """
        dtype = data.dtype
        if dtype in [np.dtype('float32'), np.dtype('complex64')]:
            return nccl.NCCL_FLOAT32
        elif dtype in [np.dtype('float64'), np.dtype('complex128')]:
            return nccl.NCCL_FLOAT64
        elif dtype in [np.dtype('int32')]:
            return nccl.NCCL_INT32
        elif dtype in [np.dtype('int64')]:
            return nccl.NCCL_INT64
        else:
            raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {data.dtype}!')

    @staticmethod
    def get_count(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        """
        if cp.iscomplexobj(data):
            return data.size * 2
        else:
            return data.size

    def get_op(self, MPI_op):
        """
        Translate an MPI reduction operation to the corresponding NCCL operation.
        """
        if MPI_op == MPI.SUM:
            return nccl.NCCL_SUM
        elif MPI_op == MPI.PROD:
            return nccl.NCCL_PROD
        elif MPI_op == MPI.MAX:
            return nccl.NCCL_MAX
        elif MPI_op == MPI.MIN:
            return nccl.NCCL_MIN
        else:
            raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        # Buffers without a device pointer are not CuPy arrays, so fall back to plain MPI.
        if not hasattr(sendbuf.data, 'ptr'):
            return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)

        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        recvbuf = cp.empty(1) if recvbuf is None else recvbuf
        stream = cp.cuda.get_current_stream()

        self.commNCCL.reduce(
            sendbuf=sendbuf.data.ptr,
            recvbuf=recvbuf.data.ptr,
            count=count,
            datatype=dtype,
            op=op,
            root=root,
            stream=stream.ptr,
        )

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        if not hasattr(sendbuf.data, 'ptr'):
            return self.commMPI.Allreduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op)

        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.allReduce(
            sendbuf=sendbuf.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=op, stream=stream.ptr
        )

    def Bcast(self, buf, root=0):
        if not hasattr(buf.data, 'ptr'):
            return self.commMPI.Bcast(buf=buf, root=root)

        dtype = self.get_dtype(buf)
        count = self.get_count(buf)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)
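
As a usage illustration (not part of the module): the class docstring above describes wrapping an MPI communicator, so a minimal sketch could look like the following, assuming mpi4py, CuPy built with NCCL support, and one visible GPU per MPI rank; the array contents and launch command are illustrative only.

# Minimal usage sketch (assumption: one GPU per MPI rank, CuPy built with NCCL support).
from mpi4py import MPI
import cupy as cp

from pySDC.helpers.NCCL_communicator import NCCLComm

comm = NCCLComm(MPI.COMM_WORLD)

# `rank` is not overridden, so it falls through to the MPI communicator via __getattr__.
data = cp.full(4, float(comm.rank), dtype='float64')
result = cp.empty_like(data)

# Both buffers live on the GPU, so this Allreduce is executed by NCCL rather than MPI.
comm.Allreduce(data, result, op=MPI.SUM)

# Launched e.g. with: mpiexec -n 2 python this_script.py  (illustrative launch line)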
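
Similarly, the complex-number workaround described in the get_dtype and get_count docstrings amounts to sending a complex buffer as twice as many real values of the matching precision. A quick check, continuing the sketch above and additionally importing nccl from cupy.cuda:

# Hypothetical check: a complex128 array of 4 elements maps to the real dtype
# of the same precision and twice the element count.
from cupy.cuda import nccl

buf = cp.empty(4, dtype='complex128')
assert NCCLComm.get_dtype(buf) == nccl.NCCL_FLOAT64  # sent as float64
assert NCCLComm.get_count(buf) == 8                  # 4 complex values -> 8 reals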