Coverage for pySDC/helpers/NCCL_communicator.py: 0%
62 statements
coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
from mpi4py import MPI
from cupy.cuda import nccl
import cupy as cp
import numpy as np


class NCCLComm(object):
    """
    Wraps an MPI communicator and performs some calls to NCCL functions instead.
    """

    def __init__(self, comm):
        """
        Args:
            comm (mpi4py.Intracomm): MPI communicator
        """
        self.commMPI = comm

        uid = comm.bcast(nccl.get_unique_id(), root=0)
        self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank)
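
    # Note (added comment, not in the original file): rank 0's NCCL unique id is broadcast over
    # the MPI communicator above, so every rank constructs a handle to the same NCCL communicator.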

    def __getattr__(self, name):
        """
        Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator.
        When performing any operations that depend on data, we have to synchronize host and device beforehand.

        Args:
            name (str): Name of the requested attribute
        """
        if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
            cp.cuda.get_current_stream().synchronize()

        return getattr(self.commMPI, name)
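
    # Usage note (added comment, not part of the original): a call such as `NCCLComm(comm).Barrier()`
    # is not overridden here, so it synchronizes the current CuPy stream and then falls through to
    # mpi4py. Attributes in the whitelist above (`rank`, `size`, ...) skip the synchronization.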

    @staticmethod
    def get_dtype(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        """
        dtype = data.dtype
        if dtype in [np.dtype('float32'), np.dtype('complex64')]:
            return nccl.NCCL_FLOAT32
        elif dtype in [np.dtype('float64'), np.dtype('complex128')]:
            return nccl.NCCL_FLOAT64
        elif dtype in [np.dtype('int32')]:
            return nccl.NCCL_INT32
        elif dtype in [np.dtype('int64')]:
            return nccl.NCCL_INT64
        else:
            raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {data.dtype}!')

    @staticmethod
    def get_count(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        """
        if cp.iscomplexobj(data):
            return data.size * 2
        else:
            return data.size
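
    # Worked example (added comment, not in the original): a cp.empty(3, dtype='complex128') buffer
    # is described to NCCL as count=6 values of NCCL_FLOAT64, i.e. each complex entry is transmitted
    # as its real and imaginary parts back to back. Sums and broadcasts stay correct under this
    # reinterpretation; a complex product would not be, so such reductions should be avoided here.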

    def get_op(self, MPI_op):
        if MPI_op == MPI.SUM:
            return nccl.NCCL_SUM
        elif MPI_op == MPI.PROD:
            return nccl.NCCL_PROD
        elif MPI_op == MPI.MAX:
            return nccl.NCCL_MAX
        elif MPI_op == MPI.MIN:
            return nccl.NCCL_MIN
        else:
            raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        if not hasattr(sendbuf.data, 'ptr'):
            return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)

        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        recvbuf = cp.empty(1) if recvbuf is None else recvbuf
        stream = cp.cuda.get_current_stream()

        self.commNCCL.reduce(
            sendbuf=sendbuf.data.ptr,
            recvbuf=recvbuf.data.ptr,
            count=count,
            datatype=dtype,
            op=op,
            root=root,
            stream=stream.ptr,
        )
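
    # Note on the dummy `cp.empty(1)` above (added comment): the NCCL call always dereferences
    # `recvbuf.data.ptr`, so a one-element placeholder is allocated when the caller passes
    # `recvbuf=None`, as mpi4py permits on non-root ranks where the result is not needed.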

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        if not hasattr(sendbuf.data, 'ptr'):
            return self.commMPI.Allreduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op)

        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.allReduce(
            sendbuf=sendbuf.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=op, stream=stream.ptr
        )

    def Bcast(self, buf, root=0):
        if not hasattr(buf.data, 'ptr'):
            return self.commMPI.Bcast(buf=buf, root=root)

        dtype = self.get_dtype(buf)
        count = self.get_count(buf)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)
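

# Minimal usage sketch (added; not part of the original module, assumes one GPU per MPI rank):
# wrap MPI.COMM_WORLD and perform a GPU-side Allreduce on a CuPy array, e.g. via
# `mpiexec -n 2 python NCCL_communicator.py`.
if __name__ == '__main__':
    comm = MPI.COMM_WORLD
    cp.cuda.Device(comm.rank).use()  # assumption: one-to-one rank-to-GPU mapping

    nccl_comm = NCCLComm(comm)

    sendbuf = cp.full(8, comm.rank, dtype='float64')
    recvbuf = cp.empty_like(sendbuf)

    # The buffers expose `.data.ptr`, so this goes through NCCL rather than falling back to mpi4py
    nccl_comm.Allreduce(sendbuf, recvbuf, op=MPI.SUM)

    cp.cuda.get_current_stream().synchronize()
    print(f'rank {comm.rank}: {recvbuf}')  # expected: 0 + 1 + ... + (size - 1) in every entry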