Coverage for pySDC/helpers/NCCL_communicator.py: 0%
51 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1from mpi4py import MPI
2from cupy.cuda import nccl
3import cupy as cp
4import numpy as np
class NCCLComm(object):
    """
    Wraps an MPI communicator and performs some calls to NCCL functions instead.

    ``Reduce`` and ``Allreduce`` are executed with NCCL on the current CuPy stream. Every other attribute is forwarded
    to the wrapped MPI communicator. Before forwarding, the current CuPy stream is synchronized, except for a few
    attributes that do not touch data.
    """

    # Attributes forwarded to MPI that do not depend on device data, so no stream synchronization is needed for them.
    _NO_SYNC_ATTRIBUTES = frozenset(['size', 'rank', 'Get_rank', 'Get_size', 'Split'])

    def __init__(self, comm):
        """
        Args:
            comm (mpi4py.Intracomm): MPI communicator
        """
        self.commMPI = comm

        # All ranks use the unique id generated on rank 0, so that they join the same NCCL communicator.
        uid = comm.bcast(nccl.get_unique_id(), root=0)
        self.commNCCL = nccl.NcclCommunicator(comm.size, uid, comm.rank)

    def __getattr__(self, name):
        """
        Pass calls that are not explicitly overridden by NCCL functionality on to the MPI communicator.
        When performing any operations that depend on data, we have to synchronize host and device beforehand.

        Python only invokes ``__getattr__`` when regular attribute lookup fails. Methods defined on this class
        (e.g. ``Reduce``) and instance attributes therefore never reach this method.

        Args:
            name (str): Name of the requested attribute

        Returns:
            The attribute ``name`` of the wrapped MPI communicator

        Raises:
            AttributeError: If the wrapped MPI communicator has not been set yet, or does not have the attribute
        """
        if name == 'commMPI':
            # `commMPI` is only missing if `__init__` has not run, e.g. while copying or unpickling an instance.
            # Resolving `self.commMPI` below would re-enter this method and recurse without end.
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')
        if name not in self._NO_SYNC_ATTRIBUTES:
            cp.cuda.get_current_stream().synchronize()
        return getattr(self.commMPI, name)

    @staticmethod
    def get_dtype(data):
        """
        Get the NCCL data type used to send ``data``.

        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        Complex data is therefore mapped to the type of its real components; see ``get_count`` for the matching count.

        Args:
            data: Array whose ``dtype`` determines the NCCL data type

        Returns:
            The corresponding ``nccl.NCCL_*`` data type constant

        Raises:
            NotImplementedError: If there is no NCCL data type for ``data.dtype``
        """
        dtype = data.dtype
        if dtype in [np.dtype('float32'), np.dtype('complex64')]:
            return nccl.NCCL_FLOAT32
        elif dtype in [np.dtype('float64'), np.dtype('complex128')]:
            return nccl.NCCL_FLOAT64
        elif dtype in [np.dtype('int32')]:
            return nccl.NCCL_INT32
        elif dtype in [np.dtype('int64')]:
            return nccl.NCCL_INT64
        else:
            raise NotImplementedError(f'Don\'t know what NCCL dtype to use to send data of dtype {data.dtype}!')

    @staticmethod
    def get_count(data):
        """
        Get the number of NCCL elements needed to send ``data``.

        As NCCL doesn't support complex numbers, we have to act as if we're sending two real numbers if using complex.
        Complex arrays therefore count two elements (real and imaginary part) per entry.

        Args:
            data: Array to be sent

        Returns:
            int: Number of elements of the type returned by ``get_dtype``
        """
        if cp.iscomplexobj(data):
            return data.size * 2
        else:
            return data.size

    def get_op(self, MPI_op):
        """
        Translate an MPI reduction operation into the equivalent NCCL operation.

        Args:
            MPI_op (mpi4py.MPI.Op): One of ``MPI.SUM``, ``MPI.PROD``, ``MPI.MAX`` or ``MPI.MIN``

        Returns:
            The corresponding ``nccl.NCCL_*`` operation constant

        Raises:
            NotImplementedError: For any other MPI operation
        """
        if MPI_op == MPI.SUM:
            return nccl.NCCL_SUM
        elif MPI_op == MPI.PROD:
            return nccl.NCCL_PROD
        elif MPI_op == MPI.MAX:
            return nccl.NCCL_MAX
        elif MPI_op == MPI.MIN:
            return nccl.NCCL_MIN
        else:
            raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')

    def _get_nccl_args(self, data, MPI_op):
        """
        Translate a buffer and an MPI operation into the data type, count and operation NCCL expects.

        Args:
            data: Array that is being reduced
            MPI_op (mpi4py.MPI.Op): Requested reduction operation

        Returns:
            tuple: NCCL data type, element count and NCCL operation

        Raises:
            NotImplementedError: If the data type or operation is unsupported. Also raised for complex data with any
                operation other than ``MPI.SUM``: complex numbers are sent as pairs of reals, and reducing the parts
                independently is only correct for summation.
        """
        nccl_op = self.get_op(MPI_op)
        if cp.iscomplexobj(data) and nccl_op != nccl.NCCL_SUM:
            raise NotImplementedError('NCCL can only reduce complex data with MPI.SUM!')
        return self.get_dtype(data), self.get_count(data), nccl_op

    def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
        """
        Reduce ``sendbuf`` from all ranks into ``recvbuf`` on rank ``root`` using NCCL on the current CuPy stream.

        Args:
            sendbuf: Device array to reduce, or ``MPI.IN_PLACE`` on the root to use ``recvbuf`` as input
            recvbuf: Device array receiving the result on the root; may be ``None`` on other ranks
            op (mpi4py.MPI.Op): Reduction operation
            root (int): Rank receiving the result

        Raises:
            NotImplementedError: For unsupported data types or operations (see ``_get_nccl_args``)
        """
        if sendbuf is MPI.IN_PLACE:
            # NCCL performs the operation in place when send and receive buffers coincide
            sendbuf = recvbuf
        dtype, count, nccl_op = self._get_nccl_args(sendbuf, op)
        # Non-root ranks may pass no receive buffer; NCCL only uses it on the root, so a dummy is enough here
        recvbuf = cp.empty(1) if recvbuf is None else recvbuf
        stream = cp.cuda.get_current_stream()

        self.commNCCL.reduce(
            sendbuf=sendbuf.data.ptr,
            recvbuf=recvbuf.data.ptr,
            count=count,
            datatype=dtype,
            op=nccl_op,
            root=root,
            stream=stream.ptr,
        )

    def Allreduce(self, sendbuf, recvbuf, op=MPI.SUM):
        """
        Reduce ``sendbuf`` from all ranks into ``recvbuf`` on every rank using NCCL on the current CuPy stream.

        Args:
            sendbuf: Device array to reduce, or ``MPI.IN_PLACE`` to use ``recvbuf`` as input
            recvbuf: Device array receiving the result
            op (mpi4py.MPI.Op): Reduction operation

        Raises:
            NotImplementedError: For unsupported data types or operations (see ``_get_nccl_args``)
        """
        if sendbuf is MPI.IN_PLACE:
            # NCCL performs the operation in place when send and receive buffers coincide
            sendbuf = recvbuf
        dtype, count, nccl_op = self._get_nccl_args(sendbuf, op)
        stream = cp.cuda.get_current_stream()

        self.commNCCL.allReduce(
            sendbuf=sendbuf.data.ptr,
            recvbuf=recvbuf.data.ptr,
            count=count,
            datatype=dtype,
            op=nccl_op,
            stream=stream.ptr,
        )