Coverage for pySDC/implementations/datatype_classes/mesh.py: 100%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1import numpy as np 

2 

3try: 

4 # TODO : mpi4py cannot be imported before dolfin when using fenics mesh 

5 # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590 

6 # This should be dealt with at some point 

7 from mpi4py import MPI 

8except ImportError: 

9 MPI = None 

10 

11 

class mesh(np.ndarray):
    """
    Numpy-based datatype for serial or parallel meshes.
    Can include a communicator and expects a dtype to allow complex data.

    Attributes:
        comm: MPI communicator or None. NOTE: ``__new__`` stores this on the class, not on the instance, so it is
            shared by all instances of the same class and reflects the most recently constructed one.
        xp: array module backing this datatype (numpy)
    """

    comm = None
    xp = np

    def __new__(cls, init, val=0.0, **kwargs):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh (its data is copied) or a tuple ``(shape, comm, dtype)`` where ``comm`` is
                ``None`` or an ``MPI.Intracomm`` and ``dtype`` is a ``numpy.dtype``
            val: value used to fill a mesh created from a tuple (not used when copying a mesh)
            **kwargs: forwarded to ``numpy.ndarray.__new__``

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a mesh nor a valid ``(shape, comm, dtype)`` tuple
        """
        if isinstance(init, mesh):
            obj = np.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
            # ``[...]`` (rather than ``[:]``) also works for 0-d meshes, e.g. results of reductions
            obj[...] = init
        elif (
            isinstance(init, tuple)
            # only touch ``MPI.Intracomm`` when mpi4py could actually be imported (MPI is None otherwise)
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], np.dtype)
        ):
            obj = np.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
            obj.fill(val)
            # The communicator is set on the class, so every instance of ``cls`` (including the views produced in
            # ``__array_ufunc__``) sees the same ``comm``.
            cls.comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh inputs and outputs are viewed as plain ndarrays so that the base implementation does not dispatch back
        here. Results are returned as meshes, explicitly provided ``out`` arrays are written to and returned, and
        ``ufunc.at`` returns None, matching numpy.
        """
        args = [input_.view(np.ndarray) if isinstance(input_, mesh) else input_ for input_ in inputs]

        # forward explicit output arrays instead of silently dropping them (this also makes ``a += b`` in-place)
        if out is not None:
            kwargs['out'] = tuple(o.view(np.ndarray) if isinstance(o, mesh) else o for o in out)

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        if results is NotImplemented:
            return NotImplemented
        if method == 'at':
            # ``ufunc.at`` operates in place and returns nothing
            return None

        if ufunc.nout == 1:
            results = (results,)
        outputs = out if out is not None else (None,) * len(results)

        wrapped = tuple(
            np.asarray(result).view(type(self)) if output is None else output
            for result, output in zip(results, outputs)
        )
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values (reduced over ``comm`` with ``MPI.MAX`` if it has more than
            one rank)
        """
        # take absolute values of the mesh values
        local_absval = float(np.max(np.ndarray.__abs__(self)))

        if self.comm is not None and self.comm.size > 1:
            global_absval = self.comm.allreduce(sendobj=local_absval, op=MPI.MAX)
        else:
            global_absval = local_absval

        return float(global_absval)

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking, via ``comm.Issend``)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator (required; the ``None`` default cannot be used)

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking, via ``comm.Irecv``); data is written into this mesh

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator (required; the ``None`` default cannot be used)

        Returns:
            request handle
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values (in place, via ``comm.Bcast``)

        Args:
            root (int): process with value to broadcast
            comm: communicator (required; the ``None`` default cannot be used)

        Returns:
            broadcasted values (this mesh itself)
        """
        comm.Bcast(self[:], root=root)
        return self

126 

127 

class MultiComponentMesh(mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_mesh(MultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    on to the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``mesh`` or ``numpy.ndarray`` because this will not result in calls to ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``MultiComponentMesh`` is just a ``numpy.ndarray`` with one more dimension, all components must have
       the same shape.
     - You can use the entire ``MultiComponentMesh`` like a ``numpy.ndarray`` in operations that accept arrays, but make
       sure that you really want to apply the same operation on all components if you do.
     - When assigning to a component, always use the slice ``[:]``, i.e. ``f.impl[:] = x`` rather than ``f.impl = x``.
       Without the slice the mesh data is not changed at all; instead a new attribute is created on the object that
       shadows the component from then on. This leads to all kinds of trouble throughout the code.
    """

    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Create a mesh with an extra leading axis of length ``len(cls.components)``.

        Args:
            init: either another mesh (passed on unchanged) or a tuple ``(shape, comm, dtype)`` describing the shape
                of a single component; ``shape`` may be an integer (plain or numpy) or a tuple of integers
            *args, **kwargs: forwarded to ``mesh.__new__`` (e.g. ``val``)

        Returns:
            obj of type ``cls``
        """
        if isinstance(init, tuple):
            # accept both Python and numpy integers as a 1D shape
            shape = (init[0],) if isinstance(init[0], (int, np.integer)) else init[0]
            obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        else:
            obj = super().__new__(cls, init, *args, **kwargs)

        return obj

    def __getattr__(self, name):
        """
        Return a ``mesh`` view on the component ``name``; only called when normal attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is not a component, or the leading axis does not match the number of
                components (including 0-d instances, e.g. results of reductions)
        """
        if name not in self.components:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")
        # check ndim first so 0-d arrays raise AttributeError (as getattr/hasattr expect) instead of IndexError
        if self.ndim > 0 and self.shape[0] == len(self.components):
            return self[self.components.index(name)].view(mesh)
        raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')

183 

184 

class imex_mesh(MultiComponentMesh):
    """
    Two-component mesh with the components ``impl`` and ``expl`` (accessed as ``f.impl`` / ``f.expl``).

    Presumably used for the implicit and explicit parts of an IMEX splitting.
    """

    # order matters: ``impl`` is index 0 and ``expl`` is index 1 along the leading axis
    components = ['impl', 'expl']

187 

188 

class comp2_mesh(MultiComponentMesh):
    """
    Generic two-component mesh with the components ``comp1`` and ``comp2`` (accessed as ``f.comp1`` / ``f.comp2``).
    """

    # order matters: ``comp1`` is index 0 and ``comp2`` is index 1 along the leading axis
    components = ['comp1', 'comp2']