Coverage for pySDC/implementations/datatype_classes/cupy_mesh.py: 86%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 09:13 +0000

import numbers

import cupy as cp

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

7 

8 

class cupy_mesh(cp.ndarray):
    """
    CuPy-based datatype for serial or parallel meshes.

    ``comm`` is a class attribute that is (re)assigned in ``__new__``. It is therefore shared by all instances of the
    class, including views created via ``.view(...)``, and the most recently constructed mesh of a class determines it.
    """

    comm = None
    xp = cp

    def __new__(cls, init, val=0.0, **kwargs):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh or a tuple containing the dimensions, the communicator and the dtype
            val: value to initialize
            **kwargs: forwarded to ``cupy.ndarray.__new__``

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a mesh nor a ``(shape, comm, dtype)`` tuple
        """
        if isinstance(init, cupy_mesh):
            # copy constructor: allocate a new buffer of the same shape/dtype and copy the values
            obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
            obj[:] = init[:]
        elif (
            isinstance(init, tuple)
            and len(init) >= 3
            # only touch MPI.Intracomm if mpi4py could actually be imported
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], cp.dtype)
        ):
            obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
            obj.fill(val)
            # stored on the class, not on the instance (see class docstring)
            cls.comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh operands (and mesh output buffers passed via ``out``) are handed to CuPy as plain ``cupy.ndarray`` views.
        The result is viewed as a ``cupy_mesh`` again.

        Returns:
            the ``out`` argument(s) if given, otherwise the result viewed as ``cupy_mesh`` (element-wise for tuple
            results), or ``NotImplemented`` if CuPy cannot handle the ufunc
        """
        args = [input_.view(cp.ndarray) if isinstance(input_, cupy_mesh) else input_ for input_ in inputs]

        if out is not None:
            # NumPy passes ``out`` as a tuple; forward it instead of silently dropping it, otherwise
            # e.g. ``numpy.add(a, b, out=(c,))`` would leave ``c`` untouched
            if not isinstance(out, tuple):
                out = (out,)
            kwargs['out'] = tuple(o.view(cp.ndarray) if isinstance(o, cupy_mesh) else o for o in out)

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)

        if results is NotImplemented:
            return NotImplemented

        if out is not None:
            # the data was written into the caller's buffers; return those objects themselves
            return out[0] if len(out) == 1 else out

        if isinstance(results, tuple):
            return tuple(r.view(cupy_mesh) if isinstance(r, cp.ndarray) else r for r in results)
        return results.view(cupy_mesh) if isinstance(results, cp.ndarray) else results

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values, reduced with MPI.MAX over ``comm`` if it has more than one
            rank; an empty (local) mesh contributes 0.0
        """
        # take absolute values of the mesh values; cp.amax raises on zero-size arrays, and a raising rank could
        # leave the other ranks waiting in the allreduce below, so an empty local mesh contributes 0.0 instead
        if self.size > 0:
            local_absval = float(cp.amax(cp.ndarray.__abs__(self)))
        else:
            local_absval = 0.0

        if self.comm is not None and self.comm.Get_size() > 1:
            return float(self.comm.allreduce(sendobj=local_absval, op=MPI.MAX))
        return local_absval

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        # NOTE(review): the device buffer is passed to MPI directly, which presumably needs a CUDA-aware MPI; mpi4py
        # leaves stream synchronisation to the caller -- confirm pending kernels are synchronised before sending.
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking); the data is received into this mesh

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values (in place)

        Args:
            root (int): process with value to broadcast
            comm: communicator

        Returns:
            broadcasted values (this mesh itself)
        """
        comm.Bcast(self[:], root=root)
        return self

119 

120 

class CuPyMultiComponentMesh(cupy_mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_cupy_mesh(CuPyMultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_cupy_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    on to the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``cupy_mesh`` or ``cupy.ndarray`` because this will not result in calls to ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``CuPyMultiComponentMesh`` is just a ``cupy.ndarray`` with one more dimension, all components must have
       the same shape.
     - You can use the entire ``CuPyMultiComponentMesh`` like a ``cupy.ndarray`` in operations that accept arrays, but make
       sure that you really want to apply the same operation on all components if you do.
     - Always assign to a component through a slice (``f.expl[:] = ...``). A plain ``f.expl = ...`` does not change
       the mesh data at all; it only binds a new attribute. Forgetting the ``[:]`` leads to all kinds of trouble
       throughout the code, and here you really cannot get away without it.
    """

    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Create a multi-component mesh by prepending one axis of length ``len(cls.components)`` to the shape.

        Args:
            init: either another mesh or a tuple ``(shape, comm, dtype, ...)``; ``shape`` may be a single integer
                (any integral scalar, e.g. also NumPy integers) or a sequence of integers
            *args, **kwargs: forwarded to ``cupy_mesh.__new__``

        Returns:
            obj of type ``cls``; built from a tuple, its shape is ``(len(cls.components), *shape)``
        """
        if isinstance(init, tuple):
            # accept any integral scalar as a 1D size, not only the built-in ``int``
            shape = (init[0],) if isinstance(init[0], numbers.Integral) else init[0]
            obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        else:
            obj = super().__new__(cls, init, *args, **kwargs)

        return obj

    def __getattr__(self, name):
        """
        Resolve component names to views (only called when regular attribute lookup fails).

        Args:
            name (str): attribute name

        Returns:
            cupy_mesh: view on the requested component, sharing memory with this mesh

        Raises:
            AttributeError: if ``name`` is not a component, or the first axis does not match the number of components
        """
        if name in self.components:
            if self.shape[0] == len(self.components):
                return self[self.components.index(name)].view(cupy_mesh)
            else:
                raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
        else:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")

176 

177 

class imex_cupy_mesh(CuPyMultiComponentMesh):
    """
    Two-component CuPy mesh whose components ``impl`` and ``expl`` are stacked along the first axis.
    """

    components = ['impl', 'expl']

180 

181 

class comp2_cupy_mesh(CuPyMultiComponentMesh):
    """
    Two-component CuPy mesh whose components ``comp1`` and ``comp2`` are stacked along the first axis.
    """

    components = ['comp1', 'comp2']