Coverage for pySDC/implementations/transfer_classes/TransferMesh_MPIFFT.py: 89%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1from pySDC.core.errors import TransferError
2from pySDC.core.space_transfer import SpaceTransfer
3from mpi4py_fft import PFFT, newDistArray
class fft_to_fft(SpaceTransfer):
    """
    Space transfer between two mpi4py-fft based problems on periodic domains.

    Handles pySDC ``mesh`` / ``cupy_mesh`` data, as well as multi-component
    data types whose class exposes ``components`` (each component is transferred
    separately). If the problem defines ``ncomp``, the data may carry an extra
    axis holding the solution components, stored either last or first.

    * Restriction evaluates the fine data in real space, keeps every
      ``ratio``-th grid point along each spatial axis and, for spectral
      problems, transforms the result with the coarse FFT.
    * Prolongation uses a ``PFFT`` built on the coarse grid with
      ``padding=self.ratio``. Its backward transform evaluates the coarse
      spectrum on the fine (padded) grid.
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators
        """
        # invoke super initialization
        super().__init__(fine_prob, coarse_prob, params)

        # Both levels must store their data in the same space (spectral or real).
        assert (
            self.fine_prob.spectral == self.coarse_prob.spectral
        ), 'fine and coarse problem must agree on `spectral`'
        self.spectral = self.fine_prob.spectral

        Nf = list(self.fine_prob.fft.global_shape())
        Nc = list(self.coarse_prob.fft.global_shape())
        # Coarsening factor per spatial axis.
        # NOTE(review): assumes Nf is an integer multiple of Nc; confirm callers guarantee this.
        self.ratio = [nf // nc for nf, nc in zip(Nf, Nc)]
        axes = tuple(range(len(Nf)))

        fft_args = {}
        # Run the padded FFT on the GPU when the problem stores its data in cupy arrays.
        useGPU = 'cupy' in self.fine_prob.dtype_u.__name__.lower()
        if useGPU:
            fft_args['backend'] = 'cupy'
            fft_args['comm_backend'] = 'NCCL'

        # Coarse-grid FFT whose backward transform outputs on the padded (fine) grid;
        # used for prolongation.
        self.fft_pad = PFFT(
            self.coarse_prob.comm,
            Nc,
            padding=self.ratio,
            axes=axes,
            dtype=self.coarse_prob.fft.dtype(False),
            slab=True,
            **fft_args,
        )

    def _component_indices(self, arr, action):
        """
        Yield the index tuples selecting each of the ``ncomp`` components of ``arr``.

        Components are taken from the last axis if its length equals
        ``self.fine_prob.ncomp``, otherwise from the first axis if that one matches.

        Args:
            arr: array carrying a component axis
            action (str): name of the operation ('restrict' / 'prolong') for the error message

        Raises:
            TransferError: if neither the first nor the last axis has length ``ncomp``
        """
        ncomp = self.fine_prob.ncomp
        for i in range(ncomp):
            if arr.shape[-1] == ncomp:
                yield (Ellipsis, i)
            elif arr.shape[0] == ncomp:
                yield (i, Ellipsis)
            else:
                raise TransferError(f"Don't know how to {action} for this problem with multiple components")

    def _subsample(self, arr):
        """
        Return ``arr`` keeping every ``ratio``-th point along each spatial axis.

        One stride is applied per entry of ``self.ratio`` to the leading axes.
        Any trailing axis beyond those (e.g. a component axis stored last) is left untouched.
        """
        return arr[tuple(slice(None, None, int(r)) for r in self.ratio)]

    def restrict(self, F):
        """
        Restriction implementation

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            the coarse level data, of the same type as ``F``

        Raises:
            TransferError: for unsupported data types or component layouts
        """
        G = type(F)(self.coarse_prob.init)

        def _restrict(fine, coarse):
            if self.spectral:
                if hasattr(self.fine_prob, 'ncomp'):
                    # Real-space scratch buffer, reused for every component.
                    tmpF = newDistArray(self.fine_prob.fft, False)
                    for idx in self._component_indices(fine, 'restrict'):
                        tmpF = self.fine_prob.fft.backward(fine[idx], tmpF)
                        coarse[idx] = self.coarse_prob.fft.forward(self._subsample(tmpF), coarse[idx])
                else:
                    tmpF = self.fine_prob.fft.backward(fine)
                    coarse[:] = self.coarse_prob.fft.forward(self._subsample(tmpF), coarse)
            elif hasattr(self.fine_prob, 'ncomp') and fine.ndim > len(self.ratio):
                # Subsample per component so that a component axis stored first is not strided.
                for idx in self._component_indices(fine, 'restrict'):
                    coarse[idx] = self._subsample(fine[idx])
            else:
                coarse[:] = self._subsample(fine)

        if hasattr(type(F), 'components'):
            for comp in F.components:
                _restrict(F.__getattr__(comp), G.__getattr__(comp))
        elif type(F).__name__ in ['mesh', 'cupy_mesh']:
            _restrict(F, G)
        else:
            raise TransferError('Wrong data type for restriction, got %s' % type(F))

        return G

    def prolong(self, G):
        """
        Prolongation implementation

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            the fine level data, of the same type as ``G``

        Raises:
            TransferError: for unsupported data types or component layouts
        """
        F = type(G)(self.fine_prob.init)

        def _prolong(coarse, fine):
            if self.spectral:
                if hasattr(self.fine_prob, 'ncomp'):
                    for idx in self._component_indices(coarse, 'prolong'):
                        tmpF = self.fft_pad.backward(coarse[idx])
                        fine[idx] = self.fine_prob.fft.forward(tmpF, fine[idx])
                else:
                    tmpF = self.fft_pad.backward(coarse)
                    fine[:] = self.fine_prob.fft.forward(tmpF, fine)
            elif hasattr(self.fine_prob, 'ncomp'):
                # Real-space data: go to coarse spectral space, then back on the padded (fine) grid.
                for idx in self._component_indices(coarse, 'prolong'):
                    G_hat = self.coarse_prob.fft.forward(coarse[idx])
                    fine[idx] = self.fft_pad.backward(G_hat, fine[idx])
            else:
                G_hat = self.coarse_prob.fft.forward(coarse)
                fine[:] = self.fft_pad.backward(G_hat, fine)

        if hasattr(type(F), 'components'):
            for comp in F.components:
                _prolong(G.__getattr__(comp), F.__getattr__(comp))
        elif type(G).__name__ in ['mesh', 'cupy_mesh']:
            _prolong(G, F)
        else:
            raise TransferError('Unknown data type, got %s' % type(G))

        return F