Coverage for pySDC/implementations/sweeper_classes/generic_implicit.py: 98%

56 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1from pySDC.core.Sweeper import sweeper 

2 

3 

class generic_implicit(sweeper):
    """
    Generic implicit sweeper, expecting lower triangular matrix type as input

    Attributes:
        QI: lower triangular matrix (the implicit Q_Delta preconditioner used in the sweep)
    """

    def __init__(self, params):
        """
        Initialization routine for the custom sweeper

        Args:
            params: parameters for the sweeper; if no 'QI' entry is given, 'IE' is written
                into the passed container before it is handed to the parent class
        """

        if 'QI' not in params:
            params['QI'] = 'IE'

        # call parent's initialization routine
        super().__init__(params)

        # get QI matrix
        self.QI = self.get_Qdelta_implicit(self.coll, qd_type=self.params.QI)

    def integrate(self):
        """
        Integrates the right-hand side over all collocation nodes

        Returns:
            list of dtype_u: entry m-1 holds dt * sum_{j=1..M} Qmat[m, j] * f[j] for node m = 1, ..., M
        """

        L = self.level
        P = L.prob
        M = self.coll.num_nodes

        me = []

        # integrate RHS over all collocation nodes (index 0 of Qmat / f is not used here)
        for m in range(1, M + 1):
            # new instance of dtype_u, initialize values with 0
            me.append(P.dtype_u(P.init, val=0.0))
            for j in range(1, M + 1):
                me[-1] += L.dt * self.coll.Qmat[m, j] * L.f[j]

        return me

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # get number of collocation nodes for easier access
        M = self.coll.num_nodes

        # update the MIN-SR-FLEX preconditioner: sweep k (k <= M) uses 'MIN-SR-FLEX<k>', later sweeps 'MIN-SR-S'.
        # The per-sweep type is kept local on purpose: overwriting self.params.QI (e.g. with 'MIN-SR-S') would make
        # the startswith check fail from then on and silently disable MIN-SR-FLEX for all subsequent time steps,
        # where the sweep counter starts again at 1.
        if self.params.QI.startswith('MIN-SR-FLEX'):
            k = L.status.sweep
            qd_type = 'MIN-SR-S' if k > M else 'MIN-SR-FLEX' + str(k)
            self.QI = self.get_Qdelta_implicit(self.coll, qd_type=qd_type)

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

        # get QF(u^k)
        integral = self.integrate()
        for m in range(M):
            # get -QdF(u^k)_m
            for j in range(1, M + 1):
                integral[m] -= L.dt * self.QI[m + 1, j] * L.f[j]

            # add initial value
            integral[m] += L.u[0]
            # add tau if associated (tau[m] belongs to collocation node m + 1)
            if L.tau[m] is not None:
                integral[m] += L.tau[m]

        # do the sweep
        for m in range(M):
            # build rhs, consisting of the known values from above and new values from previous nodes (at k+1)
            rhs = P.dtype_u(integral[m])
            for j in range(1, m + 1):
                rhs += L.dt * self.QI[m + 1, j] * L.f[j]

            # physical time of the current collocation node
            t_node = L.time + L.dt * self.coll.nodes[m]

            # implicit solve with prefactor stemming from the diagonal of Qd
            alpha = L.dt * self.QI[m + 1, m + 1]
            if alpha == 0:
                # zero diagonal entry: no system to solve, the rhs already is the new value
                L.u[m + 1] = rhs
            else:
                L.u[m + 1] = P.solve_system(rhs, alpha, L.u[m + 1], t_node)
            # update function values
            L.f[m + 1] = P.eval_f(L.u[m + 1], t_node)

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The value uend computed here is a full evaluation of the Picard formulation (u0 plus the quadrature of f
        with the collocation weights, plus the last tau correction if present), unless the right interval boundary
        is a collocation node and do_coll_update is False, in which case the value at the last node is copied.

        Returns:
            None
        """

        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            # a copy is sufficient
            L.uend = P.dtype_u(L.u[-1])
        else:
            # start with u0 and add integral over the full interval (using coll.weights)
            L.uend = P.dtype_u(L.u[0])
            for m in range(self.coll.num_nodes):
                L.uend += L.dt * self.coll.weights[m] * L.f[m + 1]
            # add up tau correction of the full interval (last entry)
            if L.tau[-1] is not None:
                L.uend += L.tau[-1]

        return None