Coverage for pySDC/helpers/NCCL_communicator.py: 0%

51 statements  

coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

from mpi4py import MPI
from cupy.cuda import nccl
import cupy as cp
import numpy as np


class NCCLComm(object):
    """
    Wraps an MPI communicator and performs some calls to NCCL functions instead.
    """

    def __init__(self, comm):
        """
        Args:
            comm (mpi4py.Intracomm): MPI communicator
        """
        self.commMPI = comm

        uid = comm.bcast(nccl.get_unique_id(), root=0)
        self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank)

    def __getattr__(self, name):
        """
        Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator.
        When performing any operations that depend on data, we have to synchronize host and device beforehand.

        Args:
            name (str): Name of the requested attribute
        """
        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
            cp.cuda.get_current_stream().synchronize()
        return getattr(self.commMPI, name)

    @staticmethod
    def get_dtype(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        """
        dtype = data.dtype
        if dtype in [np.dtype('float32'), np.dtype('complex64')]:
            return nccl.NCCL_FLOAT32
        elif dtype in [np.dtype('float64'), np.dtype('complex128')]:
            return nccl.NCCL_FLOAT64
        elif dtype in [np.dtype('int32')]:
            return nccl.NCCL_INT32
        elif dtype in [np.dtype('int64')]:
            return nccl.NCCL_INT64
        else:
            raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {data.dtype}!')

    @staticmethod
    def get_count(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        """
        if cp.iscomplexobj(data):
            return data.size * 2
        else:
            return data.size

    def get_op(self, MPI_op):
        """
        Translate an MPI reduction operation into the corresponding NCCL operation.
        """
        if MPI_op == MPI.SUM:
            return nccl.NCCL_SUM
        elif MPI_op == MPI.PROD:
            return nccl.NCCL_PROD
        elif MPI_op == MPI.MAX:
            return nccl.NCCL_MAX
        elif MPI_op == MPI.MIN:
            return nccl.NCCL_MIN
        else:
            raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        """
        Reduce the GPU data in sendbuf across all ranks into recvbuf on the root rank via NCCL.
        """
        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        recvbuf = cp.empty(1) if recvbuf is None else recvbuf
        stream = cp.cuda.get_current_stream()

        self.commNCCL.reduce(
            sendbuf=sendbuf.data.ptr,
            recvbuf=recvbuf.data.ptr,
            count=count,
            datatype=dtype,
            op=op,
            root=root,
            stream=stream.ptr,
        )

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        """
        Reduce the GPU data in sendbuf across all ranks into recvbuf on every rank via NCCL.
        """
        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.allReduce(
            sendbuf=sendbuf.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=op, stream=stream.ptr
        )
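
Below is a minimal usage sketch, not part of the module itself: it assumes one GPU per MPI rank and an MPI launch such as mpiexec -n 2, and the rank-to-device mapping is an illustrative assumption rather than something NCCLComm enforces.

from mpi4py import MPI
import cupy as cp

from pySDC.helpers.NCCL_communicator import NCCLComm

# Assumption: one GPU per rank; select the device before creating the NCCL communicator.
cp.cuda.Device(MPI.COMM_WORLD.rank % cp.cuda.runtime.getDeviceCount()).use()

comm = NCCLComm(MPI.COMM_WORLD)

# Allreduce on the GPU; complex data works because dtype/count treat it as pairs of reals.
sendbuf = cp.full(4, comm.rank + 1j * comm.rank, dtype='complex128')
recvbuf = cp.empty_like(sendbuf)
comm.Allreduce(sendbuf, recvbuf, op=MPI.SUM)

# Calls not overridden here (e.g. Barrier) fall through to the MPI communicator
# after the current CUDA stream has been synchronized.
comm.Barrier()
if comm.rank == 0:
    print(recvbuf)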