Coverage for pySDC/implementations/transfer_classes/TransferMesh_FFT.py: 100%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1import numpy as np
3from pySDC.core.errors import TransferError
4from pySDC.core.space_transfer import SpaceTransfer
class mesh_to_mesh_fft(SpaceTransfer):
    """
    Space transfer between two 1d meshes with periodic boundaries.

    Restriction is pointwise injection (every ``ratio``-th fine value is kept); prolongation is
    spectral (trigonometric) interpolation computed with real-valued FFTs (``np.fft.rfft`` /
    ``np.fft.irfft``). Supported data types are ``mesh`` and ``imex_mesh``.

    Attributes:
        ratio (int): number of fine grid points per coarse grid point (fine nvars // coarse nvars)
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators

        Raises:
            TransferError: if the number of fine grid points is not a multiple of the number of
                coarse grid points
        """
        # invoke super initialization
        super().__init__(fine_prob, coarse_prob, params)

        n_fine = self.fine_prob.nvars
        n_coarse = self.coarse_prob.nvars
        # Injection with step ``ratio`` yields exactly ``n_coarse`` values only when the grids nest;
        # reject incompatible sizes here instead of failing later with a shape mismatch.
        if n_fine % n_coarse != 0:
            raise TransferError(
                'Number of fine grid points (%s) must be a multiple of the number of coarse grid points (%s)'
                % (n_fine, n_coarse)
            )
        self.ratio = int(n_fine // n_coarse)

    def restrict(self, F):
        """
        Restriction implementation: injection, i.e. keep every ``ratio``-th fine grid value

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            the coarse level data, of the same type as ``F``

        Raises:
            TransferError: if ``F`` is neither a ``mesh`` nor an ``imex_mesh``
        """
        G = type(F)(self.coarse_prob.init, val=0.0)

        # dispatch on the exact class name (not isinstance), so subclasses are not accepted
        if type(F).__name__ == 'mesh':
            G[:] = F[:: self.ratio]
        elif type(F).__name__ == 'imex_mesh':
            G.impl[:] = F.impl[:: self.ratio]
            G.expl[:] = F.expl[:: self.ratio]
        else:
            raise TransferError('Unknown data type, got %s' % type(F))
        return G

    def prolong(self, G):
        """
        Prolongation implementation: spectral (trigonometric) interpolation of periodic data

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            the fine level data, of the same type as ``G``

        Raises:
            TransferError: if ``G`` is neither a ``mesh`` nor an ``imex_mesh``
        """
        F = type(G)(self.fine_prob.init, val=0.0)

        # dispatch on the exact class name (not isinstance), so subclasses are not accepted
        if type(G).__name__ == 'mesh':
            F[:] = self._prolong_1d(G)
        elif type(G).__name__ == 'imex_mesh':
            F.impl[:] = self._prolong_1d(G.impl)
            F.expl[:] = self._prolong_1d(G.expl)
        else:
            raise TransferError('Unknown data type, got %s' % type(G))
        return F

    def _prolong_1d(self, coarse):
        """
        Interpolate one periodic 1d coarse array onto the fine grid by zero-padding its spectrum.

        The result coincides with ``coarse`` at the coarse grid points, i.e. injecting it back
        with step ``ratio`` reproduces the input.

        Args:
            coarse: real-valued values on the coarse grid

        Returns:
            numpy.ndarray: real-valued values on the fine grid (length ``self.fine_prob.init[0]``)
        """
        n_fine = self.fine_prob.init[0]
        n_coarse = self.coarse_prob.init[0]

        coarse_hat = np.fft.rfft(coarse)
        fine_hat = np.zeros(n_fine // 2 + 1, dtype=np.complex128)

        # copy every coarse mode except the coarse Nyquist mode (which exists only for even n_coarse)
        n_regular = (n_coarse + 1) // 2
        fine_hat[:n_regular] = coarse_hat[:n_regular]

        if n_coarse % 2 == 0:
            nyquist = coarse_hat[n_coarse // 2]
            if n_fine == n_coarse:
                # identical grids: this is still the Nyquist bin of the output spectrum
                fine_hat[n_coarse // 2] = nyquist
            else:
                # On the finer grid, frequency n_coarse/2 is an ordinary mode that irfft implicitly
                # pairs with its conjugate; halve it so the interpolant matches the coarse values.
                fine_hat[n_coarse // 2] = nyquist / 2

        # pass n explicitly: irfft's default output length (2 * (len - 1)) is wrong for odd n_fine;
        # the factor ratio = n_fine / n_coarse compensates for the differing FFT normalizations
        return np.fft.irfft(fine_hat, n=n_fine) * self.ratio