Coverage for pySDC/implementations/datatype_classes/mesh.py: 97%

68 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import numpy as np 

2 

3from pySDC.core.Errors import DataError 

4 

5try: 

6 # TODO : mpi4py cannot be imported before dolfin when using fenics mesh 

7 # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590 

8 # This should be dealt with at some point 

9 from mpi4py import MPI 

10except ImportError: 

11 MPI = None 

12 

13 

class mesh(np.ndarray):
    """
    Numpy-based datatype for serial or parallel meshes.
    Can include a communicator and expects a dtype to allow complex data.

    Attributes:
        _comm: MPI communicator or None
    """

    def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh (its values and communicator are copied) or a tuple
                ``(shape, comm, dtype)`` where ``comm`` is None or an ``MPI.Intracomm`` and ``dtype`` is a
                ``numpy.dtype``
            val: value used to fill all entries (only used when ``init`` is a tuple)
            offset, buffer, strides, order: forwarded to ``numpy.ndarray.__new__``

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a mesh nor a valid ``(shape, comm, dtype)`` tuple
        """
        if isinstance(init, mesh):
            obj = np.ndarray.__new__(
                cls, shape=init.shape, dtype=init.dtype, buffer=buffer, offset=offset, strides=strides, order=order
            )
            # ``[...]`` instead of ``[:]`` so that 0-d meshes (e.g. results of reductions) can be copied as well
            obj[...] = init
            obj._comm = init._comm
        elif (
            isinstance(init, tuple)
            and len(init) >= 3
            # MPI is None when mpi4py could not be imported; then only a ``None`` communicator is valid
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], np.dtype)
        ):
            obj = np.ndarray.__new__(
                cls, init[0], dtype=init[2], buffer=buffer, offset=offset, strides=strides, order=order
            )
            obj.fill(val)
            obj._comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    @property
    def comm(self):
        """
        Getter for the communicator
        """
        return self._comm

    def __array_finalize__(self, obj):
        """
        Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
        """
        if obj is None:
            return
        self._comm = getattr(obj, '_comm', None)

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh operands, both inputs and arrays passed via ``out``, are handed to numpy as plain ndarray views.
        Freshly allocated results are viewed as ``type(self)`` and carry the communicator of the last mesh input.
        Arrays given via ``out`` are written in place and returned unchanged. This makes in-place operators such
        as ``+=`` modify the mesh itself instead of rebinding the name to a new object.
        """
        args = []
        comm = None
        for input_ in inputs:
            if isinstance(input_, mesh):
                args.append(input_.view(np.ndarray))
                comm = input_.comm
            else:
                args.append(input_)

        if out is not None:
            # numpy hands ``out`` to __array_ufunc__ as a tuple with one entry per ufunc output
            kwargs['out'] = tuple(o.view(np.ndarray) if isinstance(o, mesh) else o for o in out)
            outputs = out
        else:
            outputs = (None,) * ufunc.nout

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        if results is NotImplemented:
            return NotImplemented

        if method == 'at':
            # ``ufunc.at`` works in place on the first operand (whose view shares memory with the mesh)
            # and returns None
            return None

        if ufunc.nout == 1:
            results = (results,)

        wrapped = []
        for result, output in zip(results, outputs):
            if output is None:
                result = result.view(type(self))
                if isinstance(result, mesh):
                    result._comm = comm
            else:
                # the caller supplied this output array: return it (not the internal ndarray view)
                result = output
            wrapped.append(result)

        return wrapped[0] if len(wrapped) == 1 else tuple(wrapped)

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values. If the communicator has more than one rank, this is
            the maximum over all ranks; an empty local mesh contributes 0.0.
        """
        # np.amax raises on empty input; use 0.0 so that every rank still takes part in the allreduce below
        local_absval = float(np.amax(np.ndarray.__abs__(self))) if self.size > 0 else 0.0

        if self.comm is not None and self.comm.Get_size() > 1:
            return float(self.comm.allreduce(sendobj=local_absval, op=MPI.MAX))
        return local_absval

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking, via ``comm.Issend``)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking, via ``comm.Irecv``, into this mesh's buffer)

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle returned by ``comm.Irecv``
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values (in place, via ``comm.Bcast``)

        Args:
            root (int): process with value to broadcast
            comm: communicator

        Returns:
            this mesh, holding the broadcasted values
        """
        comm.Bcast(self[:], root=root)
        return self

150 

151 

class MultiComponentMesh(mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_mesh(MultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    on to the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``mesh`` or ``numpy.ndarray`` because this will not result in calls to ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``MultiComponentMesh`` is just a ``numpy.ndarray`` with one more dimension, all components must have
       the same shape.
     - You can use the entire ``MultiComponentMesh`` like a ``numpy.ndarray`` in operations that accept arrays, but make
       sure that you really want to apply the same operation on all components if you do.
     - Always assign to a component through a slice: ``f.expl[:] = x`` writes into the mesh, whereas ``f.expl = x``
       only binds a new attribute on ``f`` and leaves the mesh data unchanged. Forgetting the ``[:]`` leads to
       hard-to-find bugs throughout the code.
    """

    # names of the components, in the order of the leading dimension
    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Instantiate a multi-component mesh.

        Args:
            init: either another mesh (passed on unchanged) or a tuple ``(shape, comm, dtype)`` describing a single
                component, where ``shape`` is an integer (1D) or a tuple of integers
            *args, **kwargs: forwarded to ``mesh.__new__``

        Returns:
            obj of type cls; when built from a tuple, it has an extra leading dimension of size
            ``len(cls.components)``
        """
        if isinstance(init, tuple):
            # accept numpy integer scalars (e.g. np.int64) for 1D shapes, not just built-in ints
            shape = (init[0],) if isinstance(init[0], (int, np.integer)) else init[0]
            return super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        return super().__new__(cls, init, *args, **kwargs)

    def __getattr__(self, name):
        """
        Return a view on the component ``name``; only called when regular attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is not a component, or if the leading dimension does not hold exactly one
                entry per component
        """
        if name not in self.components:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")
        # a 0-d array (e.g. from a reduction) has no shape[0]; raise AttributeError, not IndexError, so that
        # hasattr() and getattr(..., default) keep working
        if self.ndim == 0 or self.shape[0] != len(self.components):
            raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
        return self[self.components.index(name)].view(mesh)

207 

208 

class imex_mesh(MultiComponentMesh):
    """
    Two-component mesh with the components ``impl`` and ``expl``.

    The names suggest the implicit and explicit parts of an IMEX splitting. The components are reachable as
    ``f.impl`` and ``f.expl`` and are stored along the first axis in this order.
    """

    # order matters: index 0 is 'impl', index 1 is 'expl'
    components = ['impl', 'expl']

211 

212 

class comp2_mesh(MultiComponentMesh):
    """
    Generic two-component mesh with the components ``comp1`` and ``comp2``.

    The components are reachable as ``f.comp1`` and ``f.comp2`` and are stored along the first axis in this order.
    """

    # order matters: index 0 is 'comp1', index 1 is 'comp2'
    components = ['comp1', 'comp2']