Coverage for pySDC/implementations/datatype_classes/cupy_mesh.py: 81%
63 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-16 14:51 +0000
1import cupy as cp
3try:
4 from mpi4py import MPI
5except ImportError:
6 MPI = None
8try:
9 from pySDC.helpers.NCCL_communicator import NCCLComm
10except ImportError:
11 NCCLComm = None
class cupy_mesh(cp.ndarray):
    """
    CuPy-based datatype for serial or parallel meshes.

    Note that the communicator given at construction is stored as a class attribute (``cls.comm``), i.e. it is
    shared by all instances of that class rather than stored per instance.
    """

    comm = None
    xp = cp

    def __new__(cls, init, val=0.0, **kwargs):
        """
        Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

        Args:
            init: either another mesh or a tuple containing the dimensions, the communicator and the dtype
            val: value to initialize (only used when ``init`` is a tuple)

        Returns:
            obj of type mesh

        Raises:
            NotImplementedError: if ``init`` is neither a ``cupy_mesh`` nor a tuple ``(shape, comm, dtype)`` with a
                supported communicator (``None``, ``mpi4py`` ``Intracomm`` or ``NCCLComm``) and a dtype
        """
        # Only check against communicator types whose package could actually be imported; otherwise
        # ``MPI.Intracomm`` (MPI is None) or ``isinstance(x, None)`` would crash instead of rejecting ``init``.
        comm_types = tuple(t for t in (getattr(MPI, 'Intracomm', None), NCCLComm) if t is not None)

        if isinstance(init, cupy_mesh):
            # copy constructor: same shape and dtype, copied values; the class communicator is left untouched
            obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
            obj[:] = init[:]
        elif (
            isinstance(init, tuple)
            and len(init) >= 3
            and (init[1] is None or isinstance(init[1], comm_types))
            and isinstance(init[2], cp.dtype)
        ):
            obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
            obj.fill(val)
            cls.comm = init[1]
        else:
            raise NotImplementedError(type(init))
        return obj

    def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
        """
        Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs

        Mesh arguments (inputs and outputs) are unwrapped to plain ``cupy.ndarray`` views before deferring to CuPy.
        Array results are wrapped as ``cupy_mesh`` again.

        Args:
            ufunc: the ufunc object that was called
            method (str): how the ufunc was called (e.g. ``'__call__'`` or ``'reduce'``)
            inputs: input arguments of the ufunc
            out: optional output array(s); NumPy passes these as a tuple
            kwargs: further keyword arguments of the ufunc

        Returns:
            the result viewed as ``cupy_mesh`` (element-wise for tuple results), or whatever CuPy returned if it is
            not an array (e.g. ``NotImplemented``, so NumPy can raise its usual ``TypeError``)
        """
        args = [input_.view(cp.ndarray) if isinstance(input_, cupy_mesh) else input_ for input_ in inputs]

        if out is not None:
            # Forward ``out`` instead of silently dropping it, so that e.g. ``numpy.add(a, b, out=c)`` writes into c.
            outs = out if isinstance(out, tuple) else (out,)
            kwargs['out'] = tuple(o.view(cp.ndarray) if isinstance(o, cupy_mesh) else o for o in outs)

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)

        if isinstance(results, tuple):
            return tuple(r.view(cupy_mesh) if isinstance(r, cp.ndarray) else r for r in results)
        return results.view(cupy_mesh) if isinstance(results, cp.ndarray) else results

    def __abs__(self):
        """
        Overloading the abs operator

        If a communicator with more than one rank is attached, the local maxima are reduced with ``MPI.MAX`` across
        all ranks.

        Returns:
            float: absolute maximum of all mesh values
        """
        # take absolute values of the mesh values
        local_absval = cp.max(cp.ndarray.__abs__(self))

        if self.comm is not None:
            if self.comm.Get_size() > 1:
                # receive buffer of the same kind as the local value (on the device), initialised to zero
                global_absval = local_absval * 0
                if isinstance(self.comm, NCCLComm):
                    self.comm.Allreduce(sendbuf=local_absval, recvbuf=global_absval, op=MPI.MAX)
                else:
                    # mpi4py lower-case allreduce works on Python objects, hence the conversion to float
                    global_absval = self.comm.allreduce(sendobj=float(local_absval), op=MPI.MAX)
            else:
                global_absval = local_absval
        else:
            global_absval = local_absval

        return float(global_absval)

    def isend(self, dest=None, tag=None, comm=None):
        """
        Routine for sending data forward in time (non-blocking, synchronous mode via ``Issend``)

        Args:
            dest (int): target rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Issend(self[:], dest=dest, tag=tag)

    def irecv(self, source=None, tag=None, comm=None):
        """
        Routine for receiving in time (non-blocking); the data is received into this mesh

        Args:
            source (int): source rank
            tag (int): communication tag
            comm: communicator

        Returns:
            request handle
        """
        return comm.Irecv(self[:], source=source, tag=tag)

    def bcast(self, root=None, comm=None):
        """
        Routine for broadcasting values (in place)

        Args:
            root (int): process with value to broadcast
            comm: communicator

        Returns:
            broadcasted values (this mesh)
        """
        comm.Bcast(self[:], root=root)
        return self
class CuPyMultiComponentMesh(cupy_mesh):
    r"""
    Generic mesh with multiple components.

    To make a specific multi-component mesh, derive from this class and list the components as strings in the class
    attribute ``components``. An example:

    ```
    class imex_cupy_mesh(CuPyMultiComponentMesh):
        components = ['impl', 'expl']
    ```

    Instantiating such a mesh will expand the mesh along an added first dimension for each component and allow access
    to the components with ``.``. Continuing the above example:

    ```
    init = ((100,), None, numpy.dtype('d'))
    f = imex_cupy_mesh(init)
    f.shape  # (2, 100)
    f.expl.shape  # (100,)
    ```

    Note that the components are not attributes of the mesh: ``"expl" in dir(f)`` will return False! Rather, the
    components are handled in ``__getattr__``. This function is called if an attribute is not found and returns a view
    on to the component if appropriate. Importantly, this means that you cannot name a component like something that
    is already an attribute of ``cupy_mesh`` or ``cupy.ndarray`` because this will not result in calls to ``__getattr__``.

    There are a couple more things to keep in mind:
     - Because a ``CuPyMultiComponentMesh`` is just a ``cupy.ndarray`` with one more dimension, all components must have
       the same shape.
     - You can use the entire ``CuPyMultiComponentMesh`` like a ``cupy.ndarray`` in operations that accept arrays, but make
       sure that you really want to apply the same operation on all components if you do.
     - If you omit the slice ``[:]`` when assigning to a component (e.g. ``f.expl = x`` instead of ``f.expl[:] = x``),
       the mesh is not changed at all. Omitting the slice causes all kinds of trouble throughout the code, but here
       there is really no way to get away without it.
    """

    components = []

    def __new__(cls, init, *args, **kwargs):
        """
        Create a mesh with an additional leading axis of length ``len(cls.components)``.

        Args:
            init: either another mesh (passed on unchanged) or a tuple ``(shape, comm, dtype)`` describing a single
                component, where ``shape`` is an integer or a sequence of integers
            args: further positional arguments passed on to ``cupy_mesh.__new__``
            kwargs: further keyword arguments passed on to ``cupy_mesh.__new__``

        Returns:
            obj of type ``cls``
        """
        if isinstance(init, tuple):
            import operator

            try:
                # a single integer (also NumPy integer types) describes a one-dimensional component
                shape = (operator.index(init[0]),)
            except TypeError:
                shape = tuple(init[0])
            obj = super().__new__(cls, ((len(cls.components), *shape), *init[1:]), *args, **kwargs)
        else:
            obj = super().__new__(cls, init, *args, **kwargs)

        return obj

    def __getattr__(self, name):
        """
        Return a ``cupy_mesh`` view on the component ``name``.

        Only called when regular attribute lookup fails.

        Args:
            name (str): name of the requested attribute

        Returns:
            cupy_mesh: view on the requested component

        Raises:
            AttributeError: if ``name`` is not a component, or the leading axis does not match the number of components
        """
        if name in self.components:
            if self.shape[0] == len(self.components):
                return self[self.components.index(name)].view(cupy_mesh)
            else:
                raise AttributeError(f'Cannot access {name!r} in {type(self)!r} because the shape is unexpected.')
        else:
            raise AttributeError(f"{type(self)!r} does not have attribute {name!r}!")
class imex_cupy_mesh(CuPyMultiComponentMesh):
    """
    Two-component CuPy mesh whose parts are reachable as ``.impl`` and ``.expl``
    (presumably the implicit and explicit parts of an IMEX splitting — confirm with the sweepers using it).
    """

    components = [
        'impl',
        'expl',
    ]
class comp2_cupy_mesh(CuPyMultiComponentMesh):
    """
    Generic two-component CuPy mesh whose parts are reachable as ``.comp1`` and ``.comp2``.
    """

    components = [
        'comp1',
        'comp2',
    ]