Coverage for pySDC/implementations/datatype_classes/mesh.py: 97%
68 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1import numpy as np
3from pySDC.core.Errors import DataError
5try:
6 # TODO : mpi4py cannot be imported before dolfin when using fenics mesh
7 # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
8 # This should be dealt with at some point
9 from mpi4py import MPI
10except ImportError:
11 MPI = None
class mesh(np.ndarray):
    """
    Numpy-based datatype for serial or parallel meshes.
    Can include a communicator and expects a dtype to allow complex data.

    Attributes:
        _comm: MPI communicator or None
    """

    def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh (its data and communicator are copied) or a tuple ``(shape, comm, dtype)``,
                where ``comm`` is ``None`` or an ``MPI.Intracomm`` and ``dtype`` is a ``numpy.dtype``
            val: value used to fill a newly created mesh (not used when copying another mesh)
            offset, buffer, strides, order: forwarded to ``numpy.ndarray.__new__``

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a mesh nor a valid ``(shape, comm, dtype)`` tuple
        """
        if isinstance(init, mesh):
            obj = np.ndarray.__new__(
                cls, shape=init.shape, dtype=init.dtype, buffer=buffer, offset=offset, strides=strides, order=order
            )
            # ``[...]`` (unlike ``[:]``) also works for 0-d meshes, e.g. results of full reductions
            obj[...] = init
            obj._comm = init._comm
        elif (
            isinstance(init, tuple)
            # MPI is None when mpi4py could not be imported; a non-None communicator is then unsupported
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], np.dtype)
        ):
            obj = np.ndarray.__new__(
                cls, init[0], dtype=init[2], buffer=buffer, offset=offset, strides=strides, order=order
            )
            obj.fill(val)
            obj._comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    @property
    def comm(self):
        """
        Getter for the communicator
        """
        return self._comm

    def __array_finalize__(self, obj):
        """
        Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
        """
        if obj is None:
            return
        self._comm = getattr(obj, '_comm', None)

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh inputs and outputs are unwrapped to plain ``numpy.ndarray`` views before numpy evaluates the ufunc.
        Newly allocated results are wrapped as ``type(self)`` and receive the communicator of the last mesh input.
        Arrays passed via ``out`` are filled in place and returned as the very same objects, so in-place operators
        such as ``a += b`` modify ``a`` instead of rebinding it to a new array.
        """
        args = []
        comm = None
        for input_ in inputs:
            if isinstance(input_, mesh):
                args.append(input_.view(np.ndarray))
                comm = input_.comm
            else:
                args.append(input_)

        if out is not None:
            # numpy hands ``out`` to overrides as a tuple; forward it so results are written into the caller's arrays
            kwargs['out'] = tuple(o.view(np.ndarray) if isinstance(o, mesh) else o for o in out)

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        if results is NotImplemented:
            return NotImplemented
        if method == 'at':
            # ufunc.at operates in place and has no result to wrap
            return None

        # multi-output ufuncs (e.g. np.divmod) return a tuple; normalize to a tuple for uniform handling
        if not isinstance(results, tuple):
            results = (results,)
        outputs = out if out is not None else (None,) * len(results)

        wrapped = []
        for result, output in zip(results, outputs):
            if output is None:
                # np.asarray turns scalar results of full reductions into 0-d arrays before viewing as a mesh
                new = np.asarray(result).view(type(self))
                new._comm = comm
                wrapped.append(new)
            else:
                wrapped.append(output)
        return wrapped[0] if len(wrapped) == 1 else tuple(wrapped)

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values (reduced over all ranks of the communicator, if any)
        """
        # take absolute values of the mesh values
        local_absval = float(np.amax(np.ndarray.__abs__(self)))

        if self.comm is not None and self.comm.Get_size() > 1:
            global_absval = max(self.comm.allreduce(sendobj=local_absval, op=MPI.MAX), 0.0)
        else:
            global_absval = local_absval

        return float(global_absval)

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking)

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle returned by ``comm.Irecv``
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values

        Args:
            root (int): process with value to broadcast
            comm: communicator

        Returns:
            broadcasted values
        """
        comm.Bcast(self[:], root=root)
        return self
class MultiComponentMesh(mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_mesh(MultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    on to the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``mesh`` or ``numpy.ndarray`` because this will not result in calls to ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``MultiComponentMesh`` is just a ``numpy.ndarray`` with one more dimension, all components must have
       the same shape.
     - You can use the entire ``MultiComponentMesh`` like a ``numpy.ndarray`` in operations that accept arrays, but make
       sure that you really want to apply the same operation on all components if you do.
     - Always assign to a component through the slice ``[:]``, i.e. write ``f.expl[:] = x`` rather than ``f.expl = x``.
       The latter does not change the mesh at all: it creates a new instance attribute ``expl`` that shadows the
       component, so that subsequent reads of ``f.expl`` no longer see the mesh data. Elsewhere in the code omitting
       ``[:]`` can sometimes go unnoticed, but here you really cannot get away without it.
    """

    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Instantiate a multi-component mesh.

        Args:
            init: either another mesh (passed on to ``mesh.__new__`` unchanged, i.e. copied) or a tuple
                ``(shape, comm, dtype)`` describing a single component; its shape is then prepended with
                ``len(cls.components)``
            *args, **kwargs: forwarded to ``mesh.__new__``

        Returns:
            obj of type ``cls``
        """
        if isinstance(init, tuple):
            # a bare integer (Python or numpy) denotes the length of a 1D component
            shape = (init[0],) if isinstance(init[0], (int, np.integer)) else init[0]
            obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        else:
            obj = super().__new__(cls, init, *args, **kwargs)

        return obj

    def __getattr__(self, name):
        """
        Return a view (of type ``mesh``) onto the component ``name``.

        Only called when regular attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is not a component, or the leading dimension does not match the number of
                components (e.g. for slices or 0-d results of reductions)
        """
        if name not in self.components:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")
        # guard ndim first: a 0-d instance has no shape[0] and would otherwise raise IndexError here
        if self.ndim == 0 or self.shape[0] != len(self.components):
            raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
        return self[self.components.index(name)].view(mesh)
class imex_mesh(MultiComponentMesh):
    """
    Two-component mesh whose components are accessible as ``.impl`` and ``.expl``.
    """

    # order matters: it defines the index of each component along the leading dimension
    components = [
        'impl',
        'expl',
    ]
class comp2_mesh(MultiComponentMesh):
    """
    Two-component mesh whose components are accessible as ``.comp1`` and ``.comp2``.
    """

    # order matters: it defines the index of each component along the leading dimension
    components = [
        'comp1',
        'comp2',
    ]