Coverage for pySDC/implementations/datatype_classes/cupy_mesh.py: 81%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 14:51 +0000

1import cupy as cp 

2 

3try: 

4 from mpi4py import MPI 

5except ImportError: 

6 MPI = None 

7 

8try: 

9 from pySDC.helpers.NCCL_communicator import NCCLComm 

10except ImportError: 

11 NCCLComm = None 

12 

13 

class cupy_mesh(cp.ndarray):
    """
    CuPy-based datatype for serial or parallel meshes.

    The communicator given at construction from a tuple is stored as a class attribute (``cls.comm``), so instances
    read the communicator passed to the most recent such construction of their class.
    """

    # communicator used by ``__abs__``; (re)assigned on the class by ``__new__``
    comm = None
    # array module backing this datatype
    xp = cp

    def __new__(cls, init, val=0.0, **kwargs):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh (copied) or a tuple ``(shape, comm, dtype)`` where ``comm`` is ``None``, an MPI
                intracommunicator or an ``NCCLComm`` and ``dtype`` is a ``cupy.dtype``
            val: value to initialize (ignored when copying from another mesh)

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a mesh nor a supported tuple
        """
        if isinstance(init, cupy_mesh):
            obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
            obj[:] = init[:]
        elif (
            isinstance(init, tuple)
            and (
                init[1] is None
                # MPI / NCCLComm are None when their optional imports failed; never use None as a type
                or (MPI is not None and isinstance(init[1], MPI.Intracomm))
                or (NCCLComm is not None and isinstance(init[1], NCCLComm))
            )
            and isinstance(init[2], cp.dtype)
        ):
            obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
            obj.fill(val)
            # NOTE(review): stored on the class, so this replaces the communicator seen by every instance of ``cls``
            cls.comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh inputs and outputs are unwrapped to plain ``cupy.ndarray`` views before delegating to CuPy. Array results
        are returned as ``cupy_mesh`` views. When ``out`` is given, the result is written into it and the caller's
        output object(s) are returned.
        """
        args = [input_.view(cp.ndarray) if isinstance(input_, cupy_mesh) else input_ for input_ in inputs]

        # the protocol passes ``out`` as a tuple; normalise a stray single object and drop an all-None tuple
        if out is not None and not isinstance(out, tuple):
            out = (out,)
        if out is not None and all(o is None for o in out):
            out = None
        if out is not None:
            # forward the outputs (as plain views sharing memory) so in-place ufunc calls really write into them
            kwargs['out'] = tuple(o.view(cp.ndarray) if isinstance(o, cupy_mesh) else o for o in out)

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)

        if results is NotImplemented:
            return NotImplemented
        if out is not None:
            return out[0] if len(out) == 1 else out

        def _wrap(res):
            # only arrays can be re-viewed as meshes; e.g. None (``ufunc.at``) or scalars pass through unchanged
            return res.view(cupy_mesh) if isinstance(res, cp.ndarray) else res

        if isinstance(results, tuple):
            return tuple(_wrap(res) for res in results)
        return _wrap(results)

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values; if ``self.comm`` has more than one rank, the maximum over all
            ranks
        """
        # take absolute values of the mesh values
        local_absval = cp.max(cp.ndarray.__abs__(self))

        if self.comm is None or self.comm.Get_size() <= 1:
            return float(local_absval)

        # NCCLComm is None when its optional import failed; fall back to the MPI path in that case
        if NCCLComm is not None and isinstance(self.comm, NCCLComm):
            # Allreduce fills a pre-allocated receive buffer of the same kind as the send buffer
            global_absval = local_absval * 0
            self.comm.Allreduce(sendbuf=local_absval, recvbuf=global_absval, op=MPI.MAX)
        else:
            # object-based reduction of a host-side Python float
            global_absval = self.comm.allreduce(sendobj=float(local_absval), op=MPI.MAX)

        return float(global_absval)

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking), writing into this mesh

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values

        Args:
            root (int): process with value to broadcast
            comm: communicator

        Returns:
            this mesh, overwritten in place with the broadcast values
        """
        comm.Bcast(self[:], root=root)
        return self

127 

128 

class CuPyMultiComponentMesh(cupy_mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_cupy_mesh(CuPyMultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_cupy_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    on to the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``cupy_mesh`` or ``cupy.ndarray`` because this will not result in calls to ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``CuPyMultiComponentMesh`` is just a ``cupy.ndarray`` with one more dimension, all components must have
       the same shape.
     - You can use the entire ``CuPyMultiComponentMesh`` like a ``cupy.ndarray`` in operations that accept arrays, but make
       sure that you really want to apply the same operation on all components if you do.
     - If you omit the assignment operator ``[:]`` during assignment, you will not change the mesh at all. Omitting this
       leads to all kinds of trouble throughout the code. But here you really cannot get away without it.
    """

    # names of the components, one per entry along the leading axis; set by subclasses
    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Instantiates a multi-component mesh.

        Args:
            init: either a tuple whose first entry is the shape of a *single* component (an integer or a sequence of
                integers) followed by the remaining ``cupy_mesh`` init entries, or anything else, which is forwarded
                unchanged to ``cupy_mesh.__new__`` (e.g. another mesh to copy)
            *args, **kwargs: forwarded to ``cupy_mesh.__new__``

        Returns:
            obj of type ``cls`` whose leading axis has length ``len(cls.components)`` (tuple ``init`` only)
        """
        import numbers

        if isinstance(init, tuple):
            # accept any integer type (e.g. numpy integers), not only the builtin ``int``
            shape = (init[0],) if isinstance(init[0], numbers.Integral) else init[0]
            obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        else:
            obj = super().__new__(cls, init, *args, **kwargs)

        return obj

    def __getattr__(self, name):
        """
        Return a ``cupy_mesh`` view onto the component called ``name``.

        Only invoked by Python when normal attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is not a component, or if the leading axis does not have one entry per component
        """
        if name in self.components:
            if self.shape[0] == len(self.components):
                return self[self.components.index(name)].view(cupy_mesh)
            else:
                raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
        else:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")

185 

class imex_cupy_mesh(CuPyMultiComponentMesh):
    """
    Two-component CuPy mesh with components ``impl`` and ``expl`` (presumably the implicit and explicit parts of an
    IMEX splitting), accessible as ``f.impl`` and ``f.expl``.
    """

    components = [
        'impl',
        'expl',
    ]

188 

189 

class comp2_cupy_mesh(CuPyMultiComponentMesh):
    """
    Two-component CuPy mesh with generic components ``comp1`` and ``comp2``, accessible as ``f.comp1`` and
    ``f.comp2``.
    """

    components = [
        'comp1',
        'comp2',
    ]