Coverage for pySDC/helpers/NCCL_communicator.py: 22%
81 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1from mpi4py import MPI
2from cupy.cuda import nccl
3import cupy as cp
4import numpy as np
class NCCLComm:
    """
    Wraps an MPI communicator and performs some calls to NCCL functions instead.

    The buffer-based collectives ``Reduce``, ``Allreduce`` and ``Bcast`` are executed with NCCL on the current CuPy
    stream whenever the buffer exposes a device pointer (``buf.data.ptr``). All other attributes and calls are
    forwarded to the wrapped MPI communicator.
    """

    # Attributes of the MPI communicator that can be forwarded without synchronizing host and device first.
    _NO_SYNC_ATTRIBUTES = frozenset(
        ('size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology')
    )

    def __init__(self, comm):
        """
        Args:
            comm (mpi4py.Intracomm): MPI communicator
        """
        self.commMPI = comm

        # Every rank generates an ID, but only the one from rank 0 is kept so that all ranks join the same
        # NCCL communicator.
        uid = comm.bcast(nccl.get_unique_id(), root=0)
        self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank)

    def __getattr__(self, name):
        """
        Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator.
        When performing any operations that depend on data, we have to synchronize host and device beforehand.

        Args:
            name (str): Name of the requested attribute

        Returns:
            The attribute of the same name of the wrapped MPI communicator

        Raises:
            AttributeError: If the MPI communicator has not been set on this instance or lacks the attribute
        """
        # This method only runs when regular lookup failed. Reading ``self.commMPI`` here would recurse without end
        # if the attribute is not set (e.g. while copying or unpickling), so go through the instance dict.
        try:
            comm = self.__dict__['commMPI']
        except KeyError:
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}') from None

        if name not in self._NO_SYNC_ATTRIBUTES:
            cp.cuda.get_current_stream().synchronize()

        return getattr(comm, name)

    @staticmethod
    def get_dtype(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.

        Args:
            data: Array-like object with a ``dtype`` attribute

        Returns:
            The NCCL data type constant matching the real-valued representation of ``data``

        Raises:
            NotImplementedError: If no NCCL data type is known for the data type of ``data``
        """
        dtype = data.dtype
        if dtype in (np.dtype('float32'), np.dtype('complex64')):
            return nccl.NCCL_FLOAT32
        if dtype in (np.dtype('float64'), np.dtype('complex128')):
            return nccl.NCCL_FLOAT64
        if dtype == np.dtype('int32'):
            return nccl.NCCL_INT32
        if dtype == np.dtype('int64'):
            return nccl.NCCL_INT64
        raise NotImplementedError(f"Don't know what NCCL dtype to use to send data of dtype {data.dtype}!")

    @staticmethod
    def get_count(data):
        """
        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.

        Args:
            data: Array-like object with a ``size`` attribute

        Returns:
            int: Number of real-valued entries to communicate, i.e. twice the size for complex data
        """
        return data.size * 2 if cp.iscomplexobj(data) else data.size

    def get_op(self, MPI_op):
        """
        Translate an MPI reduction operation to the corresponding NCCL operation.

        Args:
            MPI_op (mpi4py.MPI.Op): MPI reduction operation

        Returns:
            The NCCL constant for the same reduction

        Raises:
            NotImplementedError: If the operation is none of SUM, PROD, MAX or MIN
        """
        if MPI_op == MPI.SUM:
            return nccl.NCCL_SUM
        if MPI_op == MPI.PROD:
            return nccl.NCCL_PROD
        if MPI_op == MPI.MAX:
            return nccl.NCCL_MAX
        if MPI_op == MPI.MIN:
            return nccl.NCCL_MIN
        raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

    @staticmethod
    def _synchronize_if_on_device(obj):
        """
        Synchronize the current device if ``obj`` exposes a device pointer via ``obj.data.ptr``.

        Args:
            obj: Arbitrary object that is about to be communicated via MPI
        """
        if hasattr(getattr(obj, 'data', None), 'ptr'):
            cp.cuda.Device().synchronize()

    def reduce(self, sendobj, op=MPI.SUM, root=0):
        """
        Object-based reduction via MPI. The device is synchronized first if ``sendobj`` holds a device pointer.

        Args:
            sendobj: Object to reduce
            op (mpi4py.MPI.Op): Reduction operation
            root (int): Rank receiving the result

        Returns:
            Whatever the ``reduce`` method of the MPI communicator returns
        """
        self._synchronize_if_on_device(sendobj)
        return self.commMPI.reduce(sendobj, op=op, root=root)

    def allreduce(self, sendobj, op=MPI.SUM):
        """
        Object-based all-reduce via MPI. The device is synchronized first if ``sendobj`` holds a device pointer.

        Args:
            sendobj: Object to reduce
            op (mpi4py.MPI.Op): Reduction operation

        Returns:
            Whatever the ``allreduce`` method of the MPI communicator returns
        """
        self._synchronize_if_on_device(sendobj)
        return self.commMPI.allreduce(sendobj, op=op)

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        """
        Buffer-based reduction. Uses NCCL for device buffers and falls back to MPI otherwise.

        NOTE(review): complex data is reduced as pairs of real numbers, which is only correct for operations that act
        on real and imaginary part independently, such as SUM.

        Args:
            sendbuf: Buffer with the local contribution or ``MPI.IN_PLACE`` to take it from ``recvbuf``
            recvbuf: Buffer for the result. May be None, in which case a dummy device buffer is handed to NCCL
            op (mpi4py.MPI.Op): Reduction operation
            root (int): Rank receiving the result
        """
        in_place = sendbuf is MPI.IN_PLACE
        # With MPI.IN_PLACE the data lives in the receive buffer, so that one decides between NCCL and MPI.
        if not hasattr((recvbuf if in_place else sendbuf).data, 'ptr'):
            return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)
        if in_place:
            # NCCL operates in place when send and receive pointer are identical.
            sendbuf = recvbuf

        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        recvbuf = cp.empty(1) if recvbuf is None else recvbuf
        stream = cp.cuda.get_current_stream()

        self.commNCCL.reduce(
            sendbuf=sendbuf.data.ptr,
            recvbuf=recvbuf.data.ptr,
            count=count,
            datatype=dtype,
            op=op,
            root=root,
            stream=stream.ptr,
        )

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        """
        Buffer-based all-reduce. Uses NCCL for device buffers and falls back to MPI otherwise.

        NOTE(review): complex data is reduced as pairs of real numbers, which is only correct for operations that act
        on real and imaginary part independently, such as SUM.

        Args:
            sendbuf: Buffer with the local contribution or ``MPI.IN_PLACE`` to take it from ``recvbuf``
            recvbuf: Buffer for the result
            op (mpi4py.MPI.Op): Reduction operation
        """
        in_place = sendbuf is MPI.IN_PLACE
        # With MPI.IN_PLACE the data lives in the receive buffer, so that one decides between NCCL and MPI.
        if not hasattr((recvbuf if in_place else sendbuf).data, 'ptr'):
            return self.commMPI.Allreduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op)
        if in_place:
            # NCCL operates in place when send and receive pointer are identical.
            sendbuf = recvbuf

        dtype = self.get_dtype(sendbuf)
        count = self.get_count(sendbuf)
        op = self.get_op(op)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.allReduce(
            sendbuf=sendbuf.data.ptr, recvbuf=recvbuf.data.ptr, count=count, datatype=dtype, op=op, stream=stream.ptr
        )

    def Bcast(self, buf, root=0):
        """
        Buffer-based broadcast. Uses NCCL for device buffers and falls back to MPI otherwise.

        Args:
            buf: Buffer that is sent on ``root`` and overwritten on all other ranks
            root (int): Rank that broadcasts its data
        """
        if not hasattr(buf.data, 'ptr'):
            return self.commMPI.Bcast(buf=buf, root=root)

        dtype = self.get_dtype(buf)
        count = self.get_count(buf)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)

    def Barrier(self):
        """
        Wait for the work on the current CUDA stream to finish and then enter the MPI barrier.
        """
        cp.cuda.get_current_stream().synchronize()
        self.commMPI.Barrier()