Coverage for pySDC/implementations/datatype_classes/cupy_mesh.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import cupy as cp 

2from pySDC.core.Errors import DataError 

3 

4try: 

5 from mpi4py import MPI 

6except ImportError: 

7 MPI = None 

8 

9 

class cupy_mesh(cp.ndarray):
    """
    CuPy-based datatype for serial or parallel meshes.

    Besides the array data, every instance carries a communicator in ``_comm``
    (``None`` if no communicator was supplied), exposed read-only via ``comm``.
    """

    def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another cupy_mesh (its values and communicator are copied) or a tuple whose
                first three entries are the dimensions, the communicator (``None`` or an
                ``MPI.Intracomm``) and the dtype (a ``cp.dtype`` instance)
            val: value every entry is filled with when ``init`` is a tuple
            offset: accepted for signature compatibility, not used here
            buffer: accepted for signature compatibility, not used here
            strides: forwarded to ``cp.ndarray.__new__``
            order: forwarded to ``cp.ndarray.__new__``

        Returns:
            obj of type cupy_mesh

        Raises:
            NotImplementedError: if ``init`` is neither a cupy_mesh nor a matching tuple
        """
        if isinstance(init, cupy_mesh):
            obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, strides=strides, order=order)
            obj[:] = init[:]
            obj._comm = init._comm
            return obj

        if (
            isinstance(init, tuple)
            # a too-short tuple must end up in NotImplementedError, not IndexError
            and len(init) >= 3
            # MPI is None when mpi4py could not be imported; then only a None communicator is valid
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], cp.dtype)
        ):
            obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], strides=strides, order=order)
            obj.fill(val)
            obj._comm = init[1]
            return obj

        raise NotImplementedError(type(init))

    @property
    def comm(self):
        """
        Getter for the communicator
        """
        return self._comm

    def __array_finalize__(self, obj):
        """
        Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.

        Args:
            obj: the object this array was derived from; if it is None nothing is done, otherwise
                its ``_comm`` is taken over (``None`` if ``obj`` has no such attribute)
        """
        if obj is None:
            return
        self._comm = getattr(obj, '_comm', None)

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Inputs of type cupy_mesh are handed to the parent implementation as plain ``cp.ndarray``
        views and the result is viewed as a cupy_mesh again. Unless ``method`` is ``'reduce'``,
        the result gets the communicator of the last cupy_mesh among the inputs (``None`` if there
        is none).

        Note:
            ``out`` is accepted but not forwarded to the parent implementation.
        """
        args = []
        comm = None
        for input_ in inputs:
            if isinstance(input_, cupy_mesh):
                args.append(input_.view(cp.ndarray))
                comm = input_.comm
            else:
                args.append(input_)
        results = super().__array_ufunc__(ufunc, method, *args, **kwargs).view(cupy_mesh)
        if method != 'reduce':
            results._comm = comm
        return results

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values; if a communicator with more than one rank
                is attached, the maximum over all ranks
        """
        # largest absolute value on this rank
        local_absval = float(cp.amax(cp.ndarray.__abs__(self)))

        comm = self.comm
        if comm is not None and comm.Get_size() > 1:
            # absolute values are non-negative, so the MAX reduction needs no extra lower bound
            return float(comm.allreduce(sendobj=local_absval, op=MPI.MAX))
        return local_absval

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator, used to call ``comm.Issend``

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking)

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator, used to call ``comm.Irecv``

        Returns:
            request handle
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values

        Args:
            root (int): process with value to broadcast
            comm: communicator, used to call ``comm.Bcast``

        Returns:
            this mesh, after ``comm.Bcast`` was called on its data
        """
        comm.Bcast(self[:], root=root)
        return self

137 

138 

class imex_cupy_mesh(object):
    """
    RHS data type for cupy_meshes with implicit and explicit components

    This data type can be used to have RHS with 2 components (here implicit and explicit)

    Attributes:
        impl (cupy_mesh.cupy_mesh): implicit part
        expl (cupy_mesh.cupy_mesh): explicit part
    """

    def __init__(self, init, val=0.0):
        """
        Initialization routine

        Args:
            init: either another imex_cupy_mesh object (both components are copied) or a tuple whose
                first three entries are the dimensions, the communicator (``None`` or an
                ``MPI.Intracomm``) and the dtype (a ``cp.dtype`` instance)
            val (float): an initial number for both components when ``init`` is a tuple (default: 0.0)
        Raises:
            DataError: if init is none of the types above
        """
        if isinstance(init, type(self)):
            self.impl = cupy_mesh(init.impl)
            self.expl = cupy_mesh(init.expl)
            return

        is_mesh_spec = (
            isinstance(init, tuple)
            # a too-short tuple must end up in DataError, not IndexError
            and len(init) >= 3
            # MPI is None when mpi4py could not be imported; then only a None communicator is valid
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], cp.dtype)
        )
        if not is_mesh_spec:
            # something is wrong, if none of the ones above hit
            raise DataError('something went wrong during %s initialization' % type(self))

        self.impl = cupy_mesh(init, val=val)
        self.expl = cupy_mesh(init, val=val)

175 

176 

class comp2_cupy_mesh(object):
    """
    RHS data type for cupy_meshes with 2 components

    Attributes:
        comp1 (cupy_mesh.cupy_mesh): first part
        comp2 (cupy_mesh.cupy_mesh): second part
    """

    def __init__(self, init, val=0.0):
        """
        Initialization routine

        Args:
            init: either another comp2_cupy_mesh object (both components are copied) or a tuple whose
                first three entries are the dimensions, the communicator (``None`` or an
                ``MPI.Intracomm``) and the dtype (a ``cp.dtype`` instance)
            val (float): an initial number for both components when ``init`` is a tuple (default: 0.0)
        Raises:
            DataError: if init is none of the types above
        """
        if isinstance(init, type(self)):
            self.comp1 = cupy_mesh(init.comp1)
            self.comp2 = cupy_mesh(init.comp2)
            return

        is_mesh_spec = (
            isinstance(init, tuple)
            # a too-short tuple must end up in DataError, not IndexError
            and len(init) >= 3
            # MPI is None when mpi4py could not be imported; then only a None communicator is valid
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], cp.dtype)
        )
        if not is_mesh_spec:
            # something is wrong, if none of the ones above hit
            raise DataError('something went wrong during %s initialization' % type(self))

        self.comp1 = cupy_mesh(init, val=val)
        self.comp2 = cupy_mesh(init, val=val)