Coverage for pySDC/implementations/transfer_classes/TransferMesh_FFT.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1import numpy as np 

2 

3from pySDC.core.errors import TransferError 

4from pySDC.core.space_transfer import SpaceTransfer 

5 

6 

class mesh_to_mesh_fft(SpaceTransfer):
    """
    Space transfer between 1d meshes with periodic boundaries, based on the FFT

    Restriction takes every ``ratio``-th entry of the fine data. Prolongation
    copies the real-FFT coefficients of the coarse data into a longer, otherwise
    zero spectrum and transforms that back to the fine mesh.

    Attributes:
        ratio (int): quotient of the fine and the coarse ``nvars``, truncated to an integer
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators
        """
        # the base class is set up first; fine_prob / coarse_prob are read from self afterwards
        super().__init__(fine_prob, coarse_prob, params)

        n_fine = self.fine_prob.nvars
        n_coarse = self.coarse_prob.nvars
        self.ratio = int(n_fine / n_coarse)

    def restrict(self, F):
        """
        Restriction implementation

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            new object of the same type as F, created from the coarse problem's ``init``

        Raises:
            TransferError: if the type name of F is neither 'mesh' nor 'imex_mesh'
        """
        G = type(F)(self.coarse_prob.init, val=0.0)
        stride = self.ratio
        kind = type(F).__name__

        if kind == 'mesh':
            G[:] = F[::stride]
            return G

        if kind == 'imex_mesh':
            # both parts are restricted independently
            G.impl[:] = F.impl[::stride]
            G.expl[:] = F.expl[::stride]
            return G

        raise TransferError('Unknown data type, got %s' % type(F))

    def _fft_interpolate(self, coarse):
        """
        Interpolate one coarse array to the fine mesh by padding its spectrum

        Args:
            coarse: coarse level values

        Returns:
            values on the fine mesh
        """
        spectrum_in = np.fft.rfft(coarse)

        n_modes_fine = self.fine_prob.init[0] // 2 + 1
        spectrum_out = np.zeros(n_modes_fine, dtype=np.complex128)

        cut = self.coarse_prob.init[0] // 2
        spectrum_out[:cut] = spectrum_in[:cut]
        # NOTE(review): the last coarse coefficient lands in the last fine slot,
        # not at index `cut` -- confirm that this placement is intended
        spectrum_out[-1] = spectrum_in[-1]

        # irfft normalizes by the length of its output, which is presumably the
        # fine mesh size here, hence the rescaling with the mesh ratio
        fine_values = np.fft.irfft(spectrum_out)
        return fine_values * self.ratio

    def prolong(self, G):
        """
        Prolongation implementation

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            new object of the same type as G, created from the fine problem's ``init``

        Raises:
            TransferError: if the type name of G is neither 'mesh' nor 'imex_mesh'
        """
        F = type(G)(self.fine_prob.init, val=0.0)
        kind = type(G).__name__

        if kind == 'mesh':
            F[:] = self._fft_interpolate(G)
            return F

        if kind == 'imex_mesh':
            # both parts are prolonged independently
            F.impl[:] = self._fft_interpolate(G.impl)
            F.expl[:] = self._fft_interpolate(G.expl)
            return F

        raise TransferError('Unknown data type, got %s' % type(G))