# Coverage for pySDC/implementations/sweeper_classes/imex_1st_order_mass.py: 96%

## 51 statements

Created at 2024-09-20 16:55 +0000

1from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order

class imex_1st_order_mass(imex_1st_order):
    """
    Custom sweeper class, implements Sweeper.py

    First-order IMEX sweeper using implicit/explicit Euler as base integrator, with mass or weighting matrix.

    Mass-matrix handling (via ``prob.apply_mass_matrix``):

    - On the finest level (``level_index == 0``) the mass matrix is applied to the initial value ``u[0]``
      both when assembling the sweep right-hand side and when computing the residual.
    - On coarser levels ``u[0]`` is used as-is. In the residual, the node values ``u[m+1]`` are always
      multiplied by the mass matrix. NOTE(review): presumably coarse-level ``u[0]`` already carries the
      mass weighting from the transfer operators — confirm against the level-transfer implementation.
    """

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        For every node m (0-based), the right-hand side is assembled from
        ``u0 + QF(u^k) - QIFI(u^k) - QEFE(u^k) + tau``, where ``u0`` is mass-weighted on the finest level.
        To this are added the ``QI``/``QE`` contributions of the nodes ``0..m``; nodes ``1..m`` already
        hold this sweep's new values. ``P.solve_system`` is then called with the implicit prefactor
        ``dt * QI[m+1, m+1]``, and ``f`` is re-evaluated at the new node value.

        Returns:
            None
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # get number of collocation nodes for easier access
        M = self.coll.num_nodes

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QIFI(u^k) - QEFE(u^k) + tau

        # get QF(u^k)
        integral = self.integrate()

        # the mass matrix is applied to u0 only on the finest level (see class docstring)
        if L.level_index == 0:
            u0 = P.apply_mass_matrix(L.u[0])
        else:
            u0 = L.u[0]

        for m in range(M):
            # subtract QIFI(u^k)_m + QEFE(u^k)_m
            # (the j = 0 terms are added back in the sweep below, so they cancel out)
            for j in range(M + 1):
                integral[m] -= L.dt * (self.QI[m + 1, j] * L.f[j].impl + self.QE[m + 1, j] * L.f[j].expl)
            # add the (possibly mass-weighted) initial value
            integral[m] += u0
            # add tau if associated
            if L.tau[m] is not None:
                integral[m] += L.tau[m]

        # do the sweep
        for m in range(M):
            # build rhs, consisting of the known values from above and new values from previous nodes (at k+1)
            rhs = P.dtype_u(integral[m])
            for j in range(m + 1):
                rhs += L.dt * (self.QI[m + 1, j] * L.f[j].impl + self.QE[m + 1, j] * L.f[j].expl)

            # physical time of node m + 1 (coll.nodes is 0-based, u/f are shifted by the initial value)
            t_node = L.time + L.dt * self.coll.nodes[m]

            # implicit solve with prefactor stemming from QI
            L.u[m + 1] = P.solve_system(rhs, L.dt * self.QI[m + 1, m + 1], L.u[m + 1], t_node)
            # update function values
            L.f[m + 1] = P.eval_f(L.u[m + 1], t_node)

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        Only one case is supported: the last collocation node coincides with the right interval boundary
        (``coll.right_is_node``) and ``params.do_coll_update`` is False. ``uend`` is then a copy of the last
        node value. No collocation (Picard) update is implemented for this sweeper.

        Raises:
            NotImplementedError: if the right boundary is not a collocation node or ``do_coll_update`` is set

        Returns:
            None
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # only the "last node equals right boundary, no collocation update" case is supported
        if self.coll.right_is_node and not self.params.do_coll_update:
            # a copy is sufficient
            L.uend = P.dtype_u(L.u[-1])
        else:
            raise NotImplementedError('Mass matrix sweeper expect u_M = u_end')

        return None

    def compute_residual(self, stage=None):
        """
        Computation of the residual using the collocation matrix Q

        Per node m the residual is ``QF(u)_m + M(u0 - u_{m+1}) + tau_m`` on the finest level. On coarser
        levels it is ``QF(u)_m + u0 - M u_{m+1} + tau_m``. The stored residual is the maximum over the
        nodes of ``abs(res[m])``; ``abs`` is provided by the data type.

        If ``stage`` is listed in ``params.skip_residual_computation``, nothing is computed. A residual of
        ``None`` is replaced by ``0.0``; any existing value is kept.

        Args:
            stage (str): The current stage of the step the level belongs to

        Returns:
            None
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            L.status.residual = 0.0 if L.status.residual is None else L.status.residual
            return None

        # NOTE: L.status.updated is deliberately not asserted here

        # the mass matrix is applied to u0 only on the finest level (see class docstring)
        on_finest_level = L.level_index == 0

        # build QF(u), then complete the residual for each node
        res = self.integrate()
        res_norm = []
        for m in range(self.coll.num_nodes):
            if on_finest_level:
                res[m] += P.apply_mass_matrix(L.u[0] - L.u[m + 1])
            else:
                res[m] += L.u[0] - P.apply_mass_matrix(L.u[m + 1])
            # add tau if associated
            if L.tau[m] is not None:
                res[m] += L.tau[m]
            # Due to different boundary conditions we might have to fix the residual
            if P.fix_bc_for_residual:
                P.fix_residual(res[m])
            # use abs function from data type here
            res_norm.append(abs(res[m]))

        # find maximal residual over the nodes
        L.status.residual = max(res_norm)

        # indicate that the residual has seen the new values
        L.status.updated = False

        return None