Coverage for pySDC/implementations/datatype_classes/cupy_mesh.py: 86%
57 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 09:13 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 09:13 +0000
1import cupy as cp
3try:
4 from mpi4py import MPI
5except ImportError:
6 MPI = None
class cupy_mesh(cp.ndarray):
    """
    CuPy-based datatype for serial or parallel meshes.

    Attributes:
        comm: MPI communicator or None. This is a *class* attribute: it is assigned on the class (``cls.comm``)
            whenever a mesh is created from an ``init`` tuple, so it is shared by all instances of that class.
        xp: array module backing this datatype (``cupy``)
    """

    comm = None
    xp = cp

    def __new__(cls, init, val=0.0, **kwargs):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh or a tuple containing the dimensions, the communicator and the dtype
            val: value to initialize
            **kwargs: forwarded to ``cupy.ndarray.__new__``

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a ``cupy_mesh`` nor a supported ``(shape, comm, dtype)`` tuple
        """
        if isinstance(init, cupy_mesh):
            # copy constructor: same shape and dtype, data copied over
            obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
            obj[:] = init[:]
        elif (
            isinstance(init, tuple)
            # without mpi4py (MPI is None) only the serial case (comm None) is supported
            and (init[1] is None or (MPI is not None and isinstance(init[1], MPI.Intracomm)))
            and isinstance(init[2], cp.dtype)
        ):
            obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
            obj.fill(val)
            # stored on the class, hence shared by all instances of ``cls``
            cls.comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh inputs and output buffers are converted to plain ``cupy.ndarray`` views before delegating to the parent
        implementation, so that output buffers passed via ``out`` really receive the result. Array results are viewed
        as ``cupy_mesh`` again; ``NotImplemented`` and other non-array results are passed through unchanged.
        """
        args = [input_.view(cp.ndarray) if isinstance(input_, cupy_mesh) else input_ for input_ in inputs]

        if out is not None:
            # the ufunc protocol hands over ``out`` as a tuple; accept a bare array as well
            outs = out if isinstance(out, tuple) else (out,)
            kwargs['out'] = tuple(o.view(cp.ndarray) if isinstance(o, cupy_mesh) else o for o in outs)

        results = super(cupy_mesh, self).__array_ufunc__(ufunc, method, *args, **kwargs)

        # multi-output ufuncs return a tuple of arrays
        if isinstance(results, tuple):
            return tuple(r.view(cupy_mesh) if isinstance(r, cp.ndarray) else r for r in results)
        if isinstance(results, cp.ndarray):
            return results.view(cupy_mesh)
        # e.g. NotImplemented: let the caller fall back / raise the proper TypeError
        return results

    def __abs__(self):
        """
        Overloading the abs operator

        Returns:
            float: absolute maximum of all mesh values, reduced over all ranks of ``comm`` if it has more than one
        """
        # largest absolute value of the local data
        local_absval = float(cp.amax(cp.ndarray.__abs__(self)))

        if self.comm is not None and self.comm.Get_size() > 1:
            global_absval = self.comm.allreduce(sendobj=local_absval, op=MPI.MAX)
        else:
            global_absval = local_absval

        return float(global_absval)

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator (required despite the default: ``comm.Issend`` is called on it)

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving data in time (non-blocking)

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator (required despite the default: ``comm.Irecv`` is called on it)

        Returns:
            request handle
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values

        Args:
            root (int): process with value to broadcast
            comm: communicator (required despite the default: ``comm.Bcast`` is called on it)

        Returns:
            this mesh, holding the broadcasted values
        """
        comm.Bcast(self[:], root=root)
        return self
class CuPyMultiComponentMesh(cupy_mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_cupy_mesh(CuPyMultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_cupy_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    onto the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``cupy_mesh`` or ``cupy.ndarray`` because this will not result in calls to
    ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``CuPyMultiComponentMesh`` is just a ``cupy.ndarray`` with one more dimension, all components must
       have the same shape.
     - You can use the entire ``CuPyMultiComponentMesh`` like a ``cupy.ndarray`` in operations that accept arrays, but
       make sure that you really want to apply the same operation on all components if you do.
     - Always write into a component with slice assignment, e.g. ``f.impl[:] = ...``. If you omit ``[:]``, the mesh
       data is not changed at all, which leads to all kinds of trouble throughout the code.
    """

    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Create a mesh with one entry in an additional leading dimension per component.

        Args:
            init: either another mesh or a tuple ``(shape, comm, dtype)``, where ``shape`` is the shape of a single
                component, given as an integer (any integral type) or a sequence of integers
            *args: forwarded to ``cupy_mesh.__new__``
            **kwargs: forwarded to ``cupy_mesh.__new__``

        Returns:
            CuPyMultiComponentMesh: mesh of shape ``(len(components), *shape)`` for tuple input
        """
        import numbers

        if isinstance(init, tuple):
            # accept Python ints as well as NumPy/CuPy integer scalars for a 1D component shape
            shape = (init[0],) if isinstance(init[0], numbers.Integral) else init[0]
            obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        else:
            obj = super().__new__(cls, init, *args, **kwargs)

        return obj

    def __getattr__(self, name):
        """
        Give access to the entries of ``components`` as views onto the matching slice of the leading dimension.

        Only called when the regular attribute lookup fails.

        Args:
            name (str): name of the requested attribute

        Returns:
            cupy_mesh: view onto the component ``name``

        Raises:
            AttributeError: if ``name`` is not a component, or the mesh does not have one leading entry per component
        """
        if name in self.components:
            # a 0-d mesh has no leading dimension: report AttributeError instead of failing on shape[0]
            if self.ndim > 0 and self.shape[0] == len(self.components):
                return self[self.components.index(name)].view(cupy_mesh)
            else:
                raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
        else:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")
class imex_cupy_mesh(CuPyMultiComponentMesh):
    """
    Two-component CuPy mesh with components ``impl`` and ``expl``.

    By naming, presumably the implicit and explicit parts of an IMEX splitting.
    """

    # order matters: ``impl`` is the leading-dimension entry 0, ``expl`` is entry 1
    components = ['impl', 'expl']
class comp2_cupy_mesh(CuPyMultiComponentMesh):
    """
    Two-component CuPy mesh with generically named components ``comp1`` and ``comp2``.
    """

    # order matters: ``comp1`` is the leading-dimension entry 0, ``comp2`` is entry 1
    components = ['comp1', 'comp2']