Coverage for pySDC/implementations/sweeper_classes/imex_1st_order_mass.py: 96%
51 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 16:55 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-20 16:55 +0000
1from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
class imex_1st_order_mass(imex_1st_order):
    """
    Custom sweeper class, implements Sweeper.py

    First-order IMEX sweeper using implicit/explicit Euler as base integrator, with mass or weighting matrix.

    The methods below call ``P.apply_mass_matrix`` on the problem class ``P`` of the level; on the level
    with ``level_index == 0`` the initial value ``u[0]`` is passed through it, on all other levels ``u[0]``
    is used as stored.
    """

    def update_nodes(self):
        """
        Update the u- and f-values at the collocation nodes -> corresponds to a single sweep over all nodes

        Returns:
            None
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # only if the level has been touched before
        assert L.status.unlocked

        # get number of collocation nodes for easier access
        M = self.coll.num_nodes

        # gather all terms which are known already (e.g. from the previous iteration)
        # this corresponds to u0 + QF(u^k) - QIFI(u^k) - QEFE(u^k) + tau

        # get QF(u^k)
        integral = self.integrate()

        # This is somewhat ugly, but we have to apply the mass matrix on u0 only on the finest level
        if L.level_index == 0:
            u0 = P.apply_mass_matrix(L.u[0])
        else:
            u0 = L.u[0]

        for m in range(M):
            # subtract QIFI(u^k)_m + QEFE(u^k)_m
            for j in range(M + 1):
                integral[m] -= L.dt * (self.QI[m + 1, j] * L.f[j].impl + self.QE[m + 1, j] * L.f[j].expl)
            # add initial value
            integral[m] += u0
            # add tau if associated
            if L.tau[m] is not None:
                integral[m] += L.tau[m]

        # do the sweep
        for m in range(M):
            # build rhs, consisting of the known values from above and new values from previous nodes (at k+1)
            rhs = P.dtype_u(integral[m])
            for j in range(m + 1):
                rhs += L.dt * (self.QI[m + 1, j] * L.f[j].impl + self.QE[m + 1, j] * L.f[j].expl)

            # time handed to both the solve and the subsequent right-hand side evaluation of this node
            t_node = L.time + L.dt * self.coll.nodes[m]

            # implicit solve with prefactor stemming from QI; the old L.u[m + 1] is passed as third argument
            L.u[m + 1] = P.solve_system(rhs, L.dt * self.QI[m + 1, m + 1], L.u[m + 1], t_node)
            # update function values
            L.f[m + 1] = P.eval_f(L.u[m + 1], t_node)

        # indicate presence of new values at this level
        L.status.updated = True

        return None

    def compute_end_point(self):
        """
        Compute u at the right point of the interval

        The end value ``uend`` is set to a copy of the value at the last collocation node. This is only
        supported if the right interval boundary is a collocation node and ``do_coll_update`` is disabled.

        Returns:
            None

        Raises:
            NotImplementedError: if the right boundary is not a node or ``do_coll_update`` is enabled
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
        if self.coll.right_is_node and not self.params.do_coll_update:
            # a copy is sufficient
            L.uend = P.dtype_u(L.u[-1])
        else:
            raise NotImplementedError('Mass matrix sweeper expect u_M = u_end')

        return None

    def compute_residual(self, stage=None):
        """
        Computation of the residual using the collocation matrix Q

        The maximum over all nodes of ``abs(res[m])`` is stored in ``L.status.residual`` and
        ``L.status.updated`` is reset to False.

        Args:
            stage (str): The current stage of the step the level belongs to

        Returns:
            None
        """

        # get current level and problem description
        L = self.level
        P = L.prob

        # Check if we want to skip the residual computation to gain performance
        # Keep in mind that skipping any residual computation is likely to give incorrect outputs of the residual!
        if stage in self.params.skip_residual_computation:
            # keep a previously stored residual, fall back to 0.0 if none has been stored yet
            L.status.residual = 0.0 if L.status.residual is None else L.status.residual
            return None

        # check if there are new values (e.g. from a sweep)
        # assert L.status.updated

        # This is somewhat ugly, but we have to apply the mass matrix on u0 only on the finest level
        on_finest_level = L.level_index == 0

        # compute the residual for each node

        # build QF(u)
        res_norm = []
        res = self.integrate()
        for m in range(self.coll.num_nodes):
            if on_finest_level:
                res[m] += P.apply_mass_matrix(L.u[0] - L.u[m + 1])
            else:
                res[m] += L.u[0] - P.apply_mass_matrix(L.u[m + 1])
            # add tau if associated
            if L.tau[m] is not None:
                res[m] += L.tau[m]
            # Due to different boundary conditions we might have to fix the residual
            if P.fix_bc_for_residual:
                P.fix_residual(res[m])
            # use abs function from data type here
            res_norm.append(abs(res[m]))

        # find maximal residual over the nodes
        L.status.residual = max(res_norm)

        # indicate that the residual has seen the new values
        L.status.updated = False

        return None