Coverage for pySDC/implementations/datatype_classes/mesh.py: 100%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import numpy as np
3try:
4 # TODO : mpi4py cannot be imported before dolfin when using fenics mesh
5 # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
6 # This should be dealt with at some point
7 from mpi4py import MPI
8except ImportError:
9 MPI = None
class mesh(np.ndarray):
    """
    Numpy-based datatype for serial or parallel meshes.
    Can include a communicator and expects a dtype to allow complex data.

    Attributes:
        comm: MPI communicator or None. NOTE: ``__new__`` stores it on the *class* (``cls.comm``), not on the
            instance, so it is shared by all instances of that class, including views and ufunc results.
        xp: array module used by this datatype (``numpy``)
    """

    comm = None
    xp = np

    def __new__(cls, init, val=0.0, **kwargs):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh (its shape, dtype and data are copied) or a tuple ``(shape, comm, dtype)``
                where ``comm`` is None or an ``MPI.Intracomm`` and ``dtype`` is a ``numpy.dtype``
            val: value to initialize all entries with (only used when ``init`` is a tuple)
            **kwargs: forwarded to ``numpy.ndarray.__new__``

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a mesh nor a tuple of the form above (this includes
                passing a communicator when mpi4py could not be imported)
        """
        if isinstance(init, mesh):
            obj = np.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
            obj[:] = init[:]
        elif (
            isinstance(init, tuple)
            and len(init) >= 3
            # guard against MPI being None (mpi4py not importable) instead of failing on MPI.Intracomm
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], np.dtype)
        ):
            obj = np.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
            obj.fill(val)
            # NOTE: class-level assignment. Views and ufunc results never pass through __new__, so presumably this
            # is deliberate to let them see the communicator. It also means the most recent tuple-based
            # construction decides ``comm`` for every instance of ``cls``.
            cls.comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh inputs are unwrapped to plain ``numpy.ndarray`` views, numpy evaluates the ufunc, and array results
        are wrapped back into ``type(self)``.

        NOTE(review): ``out`` is accepted but not forwarded to numpy. An explicit ``out=`` is therefore ignored, and
        in-place operators such as ``a += b`` return a new object instead of writing into ``a``. Forwarding it would
        change aliasing semantics for existing callers -- confirm before changing.

        Returns:
            a mesh (or a tuple of meshes for multi-output ufuncs such as ``np.divmod``), ``None`` for methods that
            return nothing (``ufunc.at``), or ``NotImplemented`` if numpy defers to another operand
        """
        args = [input_.view(np.ndarray) if isinstance(input_, mesh) else input_ for input_ in inputs]

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)

        # neither None (ufunc.at) nor NotImplemented can be viewed as a mesh
        if results is None or results is NotImplemented:
            return results
        # multi-output ufuncs return a tuple of arrays
        if isinstance(results, tuple):
            return tuple(result.view(type(self)) for result in results)
        return results.view(type(self))

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values (0.0 for an empty mesh), reduced with ``MPI.MAX`` over
            ``comm`` when it has more than one rank
        """
        # ``initial=0.0`` is a valid identity since absolute values are non-negative; it avoids a ValueError on
        # zero-size (local) data
        local_absval = float(np.max(np.ndarray.__abs__(self), initial=0.0))

        if self.comm is not None and self.comm.size > 1:
            return float(self.comm.allreduce(sendobj=local_absval, op=MPI.MAX))
        return local_absval

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking, synchronous-mode ``Issend``)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking); the data is received into this mesh

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values (in place)

        Args:
            root (int): process with value to broadcast
            comm: communicator

        Returns:
            self, holding the broadcasted values
        """
        comm.Bcast(self[:], root=root)
        return self
class MultiComponentMesh(mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_mesh(MultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    onto the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``mesh`` or ``numpy.ndarray`` because this will not result in calls to ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``MultiComponentMesh`` is just a ``numpy.ndarray`` with one more dimension, all components must have
       the same shape.
     - You can use the entire ``MultiComponentMesh`` like a ``numpy.ndarray`` in operations that accept arrays, but make
       sure that you really want to apply the same operation on all components if you do.
     - If you omit the assignment operator ``[:]`` during assignment, you will not change the mesh at all. Omitting this
       leads to all kinds of trouble throughout the code. But here you really cannot get away without it.
    """

    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Instantiates a multi-component mesh.

        Args:
            init: either another mesh (handled by ``mesh.__new__``, which copies its shape and data) or a tuple
                ``(shape, comm, dtype)`` describing a *single* component. ``shape`` may be an int (Python or numpy)
                or a sequence of ints; an axis of length ``len(cls.components)`` is prepended to it.
            *args, **kwargs: forwarded to ``mesh.__new__``

        Returns:
            obj of type ``cls``
        """
        if isinstance(init, tuple):
            # numpy integers (e.g. taken from another array's shape) are scalars too and must be wrapped like ints,
            # otherwise the star-unpacking below fails
            shape = (init[0],) if isinstance(init[0], (int, np.integer)) else init[0]
            obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        else:
            obj = super().__new__(cls, init, *args, **kwargs)

        return obj

    def __getattr__(self, name):
        """
        Return a ``mesh`` view onto the component called ``name``.

        Python only calls this when regular attribute lookup fails.

        Raises:
            AttributeError: if ``name`` is not listed in ``components``, or if the leading axis does not have exactly
                one entry per component
        """
        if name in self.components:
            if self.shape[0] == len(self.components):
                return self[self.components.index(name)].view(mesh)
            else:
                raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
        else:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")
class imex_mesh(MultiComponentMesh):
    """Two-component mesh whose components are reachable as ``.impl`` and ``.expl``."""

    components = ["impl", "expl"]
class comp2_mesh(MultiComponentMesh):
    """Two-component mesh whose components are reachable as ``.comp1`` and ``.comp2``."""

    components = ["comp1", "comp2"]