Coverage for pySDC/implementations/transfer_classes/TransferMesh_MPIFFT.py: 88%

66 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1from pySDC.core.Errors import TransferError 

2from pySDC.core.SpaceTransfer import space_transfer 

3from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh 

4from mpi4py_fft import PFFT, newDistArray 

5 

6 

class fft_to_fft(space_transfer):
    """
    Custom base_transfer class, implements Transfer.py

    This implementation can restrict and prolong between PMESH datatypes meshes with FFT for periodic boundaries.

    Restriction injects every ``ratio[d]``-th real-space grid point along each axis ``d``. Prolongation uses the
    padded backward transform of an ``mpi4py_fft.PFFT`` object built on the coarse grid with ``padding=ratio``.
    For problems exposing ``ncomp`` the data is transferred component by component. The component axis is
    expected to be either the last or the first axis of the array.
    """

    def __init__(self, fine_prob, coarse_prob, params):
        """
        Initialization routine

        Args:
            fine_prob: fine problem
            coarse_prob: coarse problem
            params: parameters for the transfer operators

        Raises:
            TransferError: if fine and coarse problem disagree on ``spectral``, or if a fine grid size is not a
                positive integer multiple of the corresponding coarse grid size
        """
        # invoke super initialization
        super().__init__(fine_prob, coarse_prob, params)

        # Both levels must store their data in the same space (real or spectral). This is an explicit check
        # rather than an ``assert`` so that it is not stripped when running with ``python -O``.
        if self.fine_prob.spectral != self.coarse_prob.spectral:
            raise TransferError('Fine and coarse problem need to use the same value for "spectral"')

        self.spectral = self.fine_prob.spectral

        Nf = list(self.fine_prob.fft.global_shape())
        Nc = list(self.coarse_prob.fft.global_shape())

        # Injection (restriction) and padding (prolongation) both need an integer coarsening factor per axis.
        # A non-integer factor would otherwise be truncated and only fail later with mismatching shapes.
        if len(Nf) != len(Nc) or any(nc <= 0 or nf % nc != 0 for nf, nc in zip(Nf, Nc)):
            raise TransferError('Fine grid %s is not an integer refinement of coarse grid %s' % (Nf, Nc))
        self.ratio = [nf // nc for nf, nc in zip(Nf, Nc)]
        axes = tuple(range(len(Nf)))

        # Transform on the coarse grid with padding: its backward transform takes coarse spectral data and
        # yields fine real-space data (see _prolong_field).
        self.fft_pad = PFFT(
            self.coarse_prob.comm,
            Nc,
            padding=self.ratio,
            axes=axes,
            dtype=self.coarse_prob.fft.dtype(False),
            slab=True,
        )

    def _component_indices(self, u, action):
        """
        Return the index expressions selecting the individual components of the array ``u``.

        Problems without an ``ncomp`` attribute are treated as a single component covering the whole array.
        Otherwise the component axis is looked for first as the last, then as the first axis of ``u``.

        Args:
            u: array whose layout is inspected
            action (str): 'restrict' or 'prolong', only used in the error message

        Returns:
            list: index expressions usable as ``u[idx]``

        Raises:
            TransferError: if ``ncomp`` matches neither the last nor the first axis of ``u``
        """
        if not hasattr(self.fine_prob, 'ncomp'):
            return [Ellipsis]
        ncomp = self.fine_prob.ncomp
        if u.shape[-1] == ncomp:
            return [(Ellipsis, i) for i in range(ncomp)]
        if u.shape[0] == ncomp:
            return [(i, Ellipsis) for i in range(ncomp)]
        raise TransferError('Don\'t know how to %s for this problem with multiple components' % action)

    def _restrict_field(self, fine, coarse):
        """
        Restrict the array ``fine`` into ``coarse`` in place (handles all components of a single mesh).

        Args:
            fine: fine level array
            coarse: coarse level array, overwritten
        """
        # one stride per spatial axis, so 1D, 2D and 3D problems are all coarsened along every axis
        inject = tuple(slice(None, None, r) for r in self.ratio)
        indices = self._component_indices(fine, 'restrict')

        # Scratch array for the fine real-space data, reused for all components (multi-component case only).
        # With ``None`` the transform returns its own output buffer, which is consumed immediately below.
        work = None
        if self.spectral and hasattr(self.fine_prob, 'ncomp'):
            work = newDistArray(self.fine_prob.fft, False)

        for idx in indices:
            if self.spectral:
                # fine spectral -> fine real space -> inject -> coarse spectral
                tmpF = self.fine_prob.fft.backward(fine[idx], work)
                coarse[idx] = self.coarse_prob.fft.forward(tmpF[inject], coarse[idx])
            else:
                coarse[idx] = fine[idx][inject]

    def _prolong_field(self, coarse, fine):
        """
        Prolong the array ``coarse`` into ``fine`` in place (handles all components of a single mesh).

        Args:
            coarse: coarse level array
            fine: fine level array, overwritten
        """
        for idx in self._component_indices(coarse, 'prolong'):
            if self.spectral:
                # coarse spectral -> padded backward transform (fine real space) -> fine spectral
                tmpF = self.fft_pad.backward(coarse[idx])
                fine[idx] = self.fine_prob.fft.forward(tmpF, fine[idx])
            else:
                # coarse real space -> coarse spectral -> padded backward transform into fine real space
                G_hat = self.coarse_prob.fft.forward(coarse[idx])
                fine[idx] = self.fft_pad.backward(G_hat, fine[idx])

    def restrict(self, F):
        """
        Restriction implementation

        Args:
            F: the fine level data (easier to access than via the fine attribute)

        Returns:
            the restricted data, an instance of ``type(F)`` on the coarse grid

        Raises:
            TransferError: if ``F`` is of an unsupported type or its component layout is not recognized
        """
        G = type(F)(self.coarse_prob.init)

        if hasattr(type(F), 'components'):
            for comp in F.components:
                self._restrict_field(F.__getattr__(comp), G.__getattr__(comp))
        elif type(F).__name__ == 'mesh':
            self._restrict_field(F, G)
        else:
            raise TransferError('Wrong data type for restriction, got %s' % type(F))

        return G

    def prolong(self, G):
        """
        Prolongation implementation

        Args:
            G: the coarse level data (easier to access than via the coarse attribute)

        Returns:
            the prolonged data, an instance of ``type(G)`` on the fine grid

        Raises:
            TransferError: if ``G`` is of an unsupported type or its component layout is not recognized
        """
        F = type(G)(self.fine_prob.init)

        if hasattr(type(F), 'components'):
            for comp in F.components:
                self._prolong_field(G.__getattr__(comp), F.__getattr__(comp))
        elif type(G).__name__ == 'mesh':
            self._prolong_field(G, F)
        else:
            raise TransferError('Unknown data type, got %s' % type(G))

        return F