Coverage for pySDC/implementations/transfer_classes/TransferMesh_MPIFFT.py: 88%
66 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1from pySDC.core.Errors import TransferError
2from pySDC.core.SpaceTransfer import space_transfer
3from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
4from mpi4py_fft import PFFT, newDistArray
class fft_to_fft(space_transfer):
    """
    Custom base_transfer class, implements Transfer.py

    This implementation can restrict and prolong between PMESH datatypes meshes with FFT for periodic boundaries

    Restriction keeps every ``ratio``-th real-space point along each transformed axis. Prolongation goes
    through the padded transform ``fft_pad`` that is set up in the constructor.

    Attributes:
        spectral: flag copied from the fine problem, tells whether the data lives in spectral space
        ratio: list with the integer coarsening factor per axis (fine global shape // coarse global shape)
        fft_pad: PFFT object on the coarse grid, created with ``padding=ratio``
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators
        """
        # invoke super initialization
        super().__init__(fine_prob, coarse_prob, params)

        # both levels have to use the same representation (spectral or real space)
        assert self.fine_prob.spectral == self.coarse_prob.spectral

        self.spectral = self.fine_prob.spectral

        Nf = list(self.fine_prob.fft.global_shape())
        Nc = list(self.coarse_prob.fft.global_shape())
        # integer division gives the same factors as int(nf / nc) without the detour through float
        self.ratio = [int(nf // nc) for nf, nc in zip(Nf, Nc)]
        axes = tuple(range(len(Nf)))

        # transform on the coarse grid that uses the coarsening factors as padding factors
        self.fft_pad = PFFT(
            self.coarse_prob.comm,
            Nc,
            padding=self.ratio,
            axes=axes,
            dtype=self.coarse_prob.fft.dtype(False),
            slab=True,
        )

    def _coarsening_stride(self):
        """
        Build the index that keeps every ``ratio``-th point along each transformed axis

        Using all entries of ``self.ratio`` (instead of only the first two) makes the
        restriction independent of the number of spatial dimensions.

        Returns:
            tuple of slices, one per entry of ``self.ratio``
        """
        return tuple(slice(None, None, int(r)) for r in self.ratio)

    def _component_indices(self, data, action):
        """
        Find out along which axis the solution components are stored (spectral data)

        The last axis is checked first, then the first axis.

        Args:
            data: array holding all components
            action (str): name of the operation, only used for the error message

        Returns:
            list of index tuples, one per component, usable as ``data[idx]``

        Raises:
            TransferError: if neither the last nor the first axis has ``ncomp`` entries
        """
        ncomp = self.fine_prob.ncomp
        if data.shape[-1] == ncomp:
            return [(Ellipsis, i) for i in range(ncomp)]
        if data.shape[0] == ncomp:
            return [(i, Ellipsis) for i in range(ncomp)]
        raise TransferError('Don\'t know how to %s for this problem with multiple components' % action)

    def _components_lead(self, data):
        """
        Check whether real-space data stores its components along an extra leading axis

        Args:
            data: real-space array of a problem that defines ``ncomp``

        Returns:
            bool: True only if ``data`` has exactly one axis more than there are transformed axes,
            the first axis has ``ncomp`` entries and the last one has not
        """
        ncomp = self.fine_prob.ncomp
        has_extra_axis = len(data.shape) == len(self.ratio) + 1
        return has_extra_axis and data.shape[-1] != ncomp and data.shape[0] == ncomp

    def restrict(self, F):
        """
        Restriction implementation

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            the restricted data on the coarse level, same type as ``F``

        Raises:
            TransferError: if the type of ``F`` is not supported or the component layout is unknown
        """
        G = type(F)(self.coarse_prob.init)
        stride = self._coarsening_stride()
        multi_component = hasattr(self.fine_prob, 'ncomp')

        def _restrict(fine, coarse):
            if self.spectral:
                if multi_component:
                    # determine the layout once; raises before any work buffer is allocated
                    indices = self._component_indices(fine, 'restrict')
                    # one real-space work buffer is enough, it is overwritten for each component
                    tmpF = newDistArray(self.fine_prob.fft, False)
                    for idx in indices:
                        tmpF = self.fine_prob.fft.backward(fine[idx], tmpF)
                        coarse[idx] = self.coarse_prob.fft.forward(tmpF[stride], coarse[idx])
                else:
                    tmpF = self.fine_prob.fft.backward(fine)
                    coarse[:] = self.coarse_prob.fft.forward(tmpF[stride], coarse)
            elif multi_component and self._components_lead(fine):
                # keep the leading component axis intact and subsample only the spatial axes
                coarse[:] = fine[(slice(None),) + stride]
            else:
                # no component axis or components on the last axis: stride the leading axes
                coarse[:] = fine[stride]

        if hasattr(type(F), 'components'):
            # data type with named components: transfer each of them separately
            for comp in F.components:
                _restrict(F.__getattr__(comp), G.__getattr__(comp))
        elif type(F).__name__ == 'mesh':
            _restrict(F, G)
        else:
            raise TransferError('Wrong data type for restriction, got %s' % type(F))

        return G

    def prolong(self, G):
        """
        Prolongation implementation

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            the prolonged data on the fine level, same type as ``G``

        Raises:
            TransferError: if the type of ``G`` is not supported or the component layout is unknown
        """
        F = type(G)(self.fine_prob.init)
        multi_component = hasattr(self.fine_prob, 'ncomp')

        def _prolong(coarse, fine):
            if self.spectral:
                if multi_component:
                    for idx in self._component_indices(coarse, 'prolong'):
                        # padded backward transform yields real-space data of the fine problem
                        tmpF = self.fft_pad.backward(coarse[idx])
                        fine[idx] = self.fine_prob.fft.forward(tmpF, fine[idx])
                else:
                    tmpF = self.fft_pad.backward(coarse)
                    fine[:] = self.fine_prob.fft.forward(tmpF, fine)
            else:
                if multi_component:
                    ncomp = self.fine_prob.ncomp
                    if self._components_lead(coarse):
                        indices = [(i, Ellipsis) for i in range(ncomp)]
                    else:
                        # default layout: components on the last axis
                        indices = [(Ellipsis, i) for i in range(ncomp)]
                    for idx in indices:
                        G_hat = self.coarse_prob.fft.forward(coarse[idx])
                        fine[idx] = self.fft_pad.backward(G_hat, fine[idx])
                else:
                    G_hat = self.coarse_prob.fft.forward(coarse)
                    fine[:] = self.fft_pad.backward(G_hat, fine)

        if hasattr(type(F), 'components'):
            # data type with named components: transfer each of them separately
            for comp in F.components:
                _prolong(G.__getattr__(comp), F.__getattr__(comp))
        elif type(G).__name__ == 'mesh':
            _prolong(G, F)
        else:
            raise TransferError('Unknown data type, got %s' % type(G))

        return F