Coverage for pySDC/implementations/transfer_classes/TransferMesh_FFT2D.py: 57%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-09 14:59 +0000

1import numpy as np 

2 

3from pySDC.core.errors import TransferError 

4from pySDC.core.space_transfer import SpaceTransfer 

5from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh 

6 

7 

class mesh_to_mesh_fft2d(SpaceTransfer):
    """
    Custom space-transfer class, implements SpaceTransfer.

    This implementation can restrict and prolong between square 2D meshes with
    periodic boundaries:

    - restriction is pointwise injection (every ``ratio``-th point per direction),
    - prolongation is spectral interpolation by zero-padding the 2D FFT.

    Attributes:
        ratio (int): coarsening factor per direction, ``fine nvars[0] // coarse nvars[0]``
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators
        """
        # invoke super initialization
        super().__init__(fine_prob, coarse_prob, params)

        # TODO: cleanup and move to real-valued FFT
        assert len(self.fine_prob.nvars) == 2
        assert len(self.coarse_prob.nvars) == 2
        assert self.fine_prob.nvars[0] == self.fine_prob.nvars[1]
        assert self.coarse_prob.nvars[0] == self.coarse_prob.nvars[1]

        self.ratio = self.fine_prob.nvars[0] // self.coarse_prob.nvars[0]

    def _prolong_component(self, coarse_values):
        """
        Spectrally interpolate a single 2D coarse-grid array to the fine grid.

        The 2D FFT of the coarse data is copied into the four corners of a
        zero-initialized fine-grid spectrum (non-negative wave numbers at the
        start, negative wave numbers at the end of each axis), transformed back
        and rescaled so amplitudes are preserved.

        Args:
            coarse_values: 2D array of shape (Nc, Nc) on the coarse grid

        Returns:
            numpy.ndarray: real-valued 2D array on the fine grid
        """
        n_fine = self.fine_prob.init[0][0]
        n_coarse = self.coarse_prob.init[0][0]

        # number of non-negative (incl. zero) and negative wave numbers kept from the
        # coarse spectrum; for even n_coarse this is n_coarse/2 each (the Nyquist mode
        # lands on the negative side), for odd n_coarse it is (n+1)/2 and (n-1)/2
        n_pos = (n_coarse + 1) // 2
        n_neg = n_coarse - n_pos
        neg_start = n_fine - n_neg  # first negative wave number index on the fine grid

        coarse_hat = np.fft.fft2(coarse_values)
        fine_hat = np.zeros(self.fine_prob.init[0], dtype=np.complex128)
        fine_hat[:n_pos, :n_pos] = coarse_hat[:n_pos, :n_pos]
        fine_hat[neg_start:, :n_pos] = coarse_hat[n_pos:, :n_pos]
        fine_hat[:n_pos, neg_start:] = coarse_hat[:n_pos, n_pos:]
        fine_hat[neg_start:, neg_start:] = coarse_hat[n_pos:, n_pos:]

        # fft2 of the coarse data scales with n_coarse**2 while ifft2 divides by
        # n_fine**2, hence the (n_fine / n_coarse)**2 rescaling (== ratio**2)
        return np.real(np.fft.ifft2(fine_hat)) * (n_fine / n_coarse) ** 2

    def restrict(self, F):
        """
        Restriction implementation (pointwise injection)

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            coarse level data of the same datatype as F

        Raises:
            TransferError: if F is neither an imex_mesh nor a mesh
        """
        # check the more specific imex_mesh first: if imex_mesh derives from mesh
        # (as in recent pySDC versions -- NOTE(review): confirm), a mesh check first
        # would swallow it and slice the component axis instead of the grid
        if isinstance(F, imex_mesh):
            G = imex_mesh(self.coarse_prob.init, val=0.0)
            G.impl[:] = F.impl[:: self.ratio, :: self.ratio]
            G.expl[:] = F.expl[:: self.ratio, :: self.ratio]
        elif isinstance(F, mesh):
            G = mesh(self.coarse_prob.init, val=0.0)
            G[:] = F[:: self.ratio, :: self.ratio]
        else:
            raise TransferError('Unknown data type, got %s' % type(F))
        return G

    def prolong(self, G):
        """
        Prolongation implementation (spectral interpolation via FFT zero-padding)

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            fine level data of the same datatype as G

        Raises:
            TransferError: if G is neither an imex_mesh nor a mesh
        """
        # imex_mesh first, for the same reason as in restrict
        if isinstance(G, imex_mesh):
            # allocate on the fine grid (not a copy of the coarse data)
            F = imex_mesh(self.fine_prob.init)
            F.impl[:] = self._prolong_component(G.impl)
            F.expl[:] = self._prolong_component(G.expl)
        elif isinstance(G, mesh):
            F = mesh(self.fine_prob.init)
            F[:] = self._prolong_component(G)
        else:
            raise TransferError('Unknown data type, got %s' % type(G))
        return F