PFASST++
implicit_sweeper_impl.hpp
Go to the documentation of this file.
2 
3 #include <cassert>
4 
5 #include "pfasst/globals.hpp"
6 #include "pfasst/logging.hpp"
7 
8 using namespace std;
9 
10 namespace pfasst
11 {
12  namespace encap
13  {
14  template<typename scalar>
15  using lu_pair = pair< Matrix<scalar>, Matrix<scalar> >;
16 
20  template<typename scalar>
21  static lu_pair<scalar> lu_decomposition(const Matrix<scalar>& A)
22  {
23  assert(A.rows() == A.cols());
24 
25  auto n = A.rows();
26 
27  Matrix<scalar> L = Matrix<scalar>::Zero(n, n);
28  Matrix<scalar> U = Matrix<scalar>::Zero(n, n);
29 
30  if (A.rows() == 1) {
31 
32  L(0, 0) = 1.0;
33  U(0, 0) = A(0,0);
34 
35  } else {
36 
37  // first row of U is first row of A
38  auto U12 = A.block(0, 1, 1, n-1);
39 
40  // first column of L is first column of A / a11
41  auto L21 = A.block(1, 0, n-1, 1) / A(0, 0);
42 
43  // remove first row and column and recurse
44  auto A22 = A.block(1, 1, n-1, n-1);
45  Matrix<scalar> tmp = A22 - L21 * U12;
46  auto LU22 = lu_decomposition(tmp);
47 
48  L(0, 0) = 1.0;
49  U(0, 0) = A(0, 0);
50  L.block(1, 0, n-1, 1) = L21;
51  U.block(0, 1, 1, n-1) = U12;
52  L.block(1, 1, n-1, n-1) = get<0>(LU22);
53  U.block(1, 1, n-1, n-1) = get<1>(LU22);
54 
55  }
56 
57  return lu_pair<scalar>(L, U);
58  }
59 
/**
 * Build the list of absolute times for a step.
 *
 * The first entry is the step start t0. It is followed by t0 + dt * node for every
 * (relative) quadrature node, in order.
 *
 * @param t0    start time of the step.
 * @param dt    step size.
 * @param nodes relative node positions.
 * @return vector of length nodes.size() + 1.
 */
template<typename time>
vector<time> augment(time t0, time dt, vector<time> const & nodes)
{
  vector<time> times;
  times.reserve(nodes.size() + 1);
  times.push_back(t0);
  for (auto const& node : nodes) {
    times.push_back(t0 + dt * node);
  }
  return times;
}
72  }
73 
74  /*
75  * Implementations
76  */
77 
// NOTE(review): the declaration line (numbered 79 in this listing) is missing here.
// Presumably it is `void ImplicitSweeper<time>::set_end_state()`, which predict() and
// sweep() call via this->set_end_state(). Confirm against the original source.
//
// Fills end_state with the solution at the end of the current step:
//  - If the quadrature's right endpoint is itself a node, end_state is a copy of the
//    last node state (state.back()).
//  - Otherwise end_state starts as a copy of start_state. mat_apply is then called on it
//    with the step size, the quadrature's b_mat and fs_impl. The trailing `false`
//    presumably means "accumulate into dst without zeroing it first" (sweep() passes
//    `true` for s_integrals). Confirm against Encapsulation::mat_apply.
78  template<typename time>
80  {
81  if (this->quadrature->right_is_node()) {
82  this->end_state->copy(this->state.back());
83  } else {
// mat_apply operates on a vector of destinations; this one-element vector refers to end_state.
84  vector<shared_ptr<Encapsulation<time>>> dst = { this->end_state };
85  dst[0]->copy(this->start_state);
86  dst[0]->mat_apply(dst, this->get_controller()->get_step_size(), this->quadrature->get_b_mat(), this->fs_impl, false);
87  }
88  }
89 
// NOTE(review): one line of this body (numbered 93 in this listing) is missing from the
// page. Confirm its content against the original source before editing this function.
//
// Allocates per-node storage and precomputes q_tilde for the sweep:
//  - Throws ValueError if the quadrature includes the left endpoint as a node.
//  - Appends num_nodes solution encapsulations to s_integrals and num_nodes function
//    encapsulations to fs_impl, all created through the sweeper's factory.
//  - Sets q_tilde = U^T, where (L, U) = lu_decomposition(Q^T). U is upper-triangular,
//    so q_tilde is lower-triangular.
// The `coarse` parameter is not referenced in the visible lines of the body.
90  template<typename time>
91  void ImplicitSweeper<time>::setup(bool coarse)
92  {
94 
// NOTE(review): `nodes` is never used below and could be removed.
95  auto const nodes = this->quadrature->get_nodes();
96  auto const num_nodes = this->quadrature->get_num_nodes();
97 
98  if (this->quadrature->left_is_node()) {
// NOTE(review): this error condition is logged at INFO level; ERROR/WARNING may be intended.
99  ML_CLOG(INFO, "Sweeper", "implicit sweeper shouldn't include left endpoint");
100  throw ValueError("implicit sweeper shouldn't include left endpoint");
101  }
102 
103  for (size_t m = 0; m < num_nodes; m++) {
104  this->s_integrals.push_back(this->get_factory()->create(pfasst::encap::solution));
105  this->fs_impl.push_back(this->get_factory()->create(pfasst::encap::function));
106  }
107 
// Factor Q^T = L * U (no pivoting); only U is kept, transposed, as q_tilde.
108  Matrix<time> QT = this->quadrature->get_q_mat().transpose();
109  auto lu = lu_decomposition(QT);
110  auto L = get<0>(lu);
111  auto U = get<1>(lu);
112  this->q_tilde = U.transpose();
113 
// Debug output of the factorisation; L*U should reproduce Q'.
114  ML_CLOG(DEBUG, "Sweeper", "Q':" << endl << QT);
115  ML_CLOG(DEBUG, "Sweeper", "L:" << endl << L);
116  ML_CLOG(DEBUG, "Sweeper", "U:" << endl << U);
117  ML_CLOG(DEBUG, "Sweeper", "LU:" << endl << L * U);
118  ML_CLOG(DEBUG, "Sweeper", "q_tilde:" << endl << this->q_tilde);
119  }
120 
// NOTE(review): the declaration line (numbered 122 in this listing) is missing. Given the
// UNUSED(initial) below, this is presumably `void ImplicitSweeper<time>::predict(bool initial)`.
// Confirm against the original source.
//
// Predictor: walks the nodes left to right and calls impl_solve once per node. Each call
// uses the previous node's state (the step's start state for the first node) as the
// right-hand side and the node spacing as the sub-step. Afterwards end_state is updated.
// The `initial` flag is ignored.
121  template<typename time>
123  {
124  UNUSED(initial);
125 
126  auto const dt = this->get_controller()->get_step_size();
127  auto const t = this->get_controller()->get_time();
128 
129  ML_CLOG(INFO, "Sweeper", "predicting step " << this->get_controller()->get_step() + 1
130  << " (t=" << t << ", dt=" << dt << ")");
131 
// anodes = [t, t + dt*nodes[0], ..., t + dt*nodes[M-1]]. So anodes[m+1] - anodes[m] is the
// sub-step that ends at node m, and anodes[m] is its left end.
132  auto const anodes = augment(t, dt, this->quadrature->get_nodes());
133  for (size_t m = 0; m < anodes.size() - 1; ++m) {
// impl_solve's contract is not visible here. From the call pattern it appears to compute
// state[m] and its implicit function value fs_impl[m], over the given sub-step, from the
// given right-hand side. Confirm against the sweeper's declaration.
134  this->impl_solve(this->fs_impl[m], this->state[m], anodes[m], anodes[m+1] - anodes[m],
135  m == 0 ? this->get_start_state() : this->state[m-1]);
136  }
137 
138  this->set_end_state();
139  }
140 
// NOTE(review): the declaration line (numbered 142 in this listing) is missing. It is
// presumably `void ImplicitSweeper<time>::sweep()`. Confirm against the original source.
//
// One correction sweep over all nodes of the current step:
//  1. s_integrals <- dt * S applied to fs_impl (mat_apply with the quadrature's s_mat;
//     the trailing `true` presumably zeroes the destination first). FAS corrections
//     are then added when present.
//  2. For each node m, subtract sum_{n<m} dt * q_tilde(m, n) * fs_impl[n], using the
//     function values from before this sweep.
//  3. Walk the nodes left to right and build
//       rhs = previous state (start state for m == 0) + s_integrals[m]
//             - ds * fs_impl[m] + sum_{n<m} dt * q_tilde(m, n) * fs_impl[n],
//     then call impl_solve. Assuming impl_solve overwrites its first argument, fs_impl[n]
//     for n < m already holds the new value here. The q_tilde terms of steps 2 and 3 then
//     combine into dt * q_tilde(m, n) * (new - old).
//  4. Update end_state.
142  template<typename time>
143  {
144  auto const dt = this->get_controller()->get_step_size();
145  auto const t = this->get_controller()->get_time();
146 
147  ML_CLOG(INFO, "Sweeper", "sweeping on step " << this->get_controller()->get_step() + 1
148  << " in iteration " << this->get_controller()->get_iteration()
149  << " (dt=" << dt << ")");
150 
151  this->s_integrals[0]->mat_apply(this->s_integrals, dt, this->quadrature->get_s_mat(), this->fs_impl, true);
// Assumes fas_corrections, when non-empty, has at least s_integrals.size() entries.
152  if (this->fas_corrections.size() > 0) {
153  for (size_t m = 0; m < this->s_integrals.size(); m++) {
154  this->s_integrals[m]->saxpy(1.0, this->fas_corrections[m]);
155  }
156  }
157 
// Remove the strictly-lower q_tilde contribution of the old function values.
158  for (size_t m = 0; m < this->s_integrals.size(); m++) {
159  for (size_t n = 0; n < m; n++) {
160  this->s_integrals[m]->saxpy(-dt*this->q_tilde(m, n), this->fs_impl[n]);
161  }
162  }
163 
// NOTE(review): a fresh rhs encapsulation is created on every sweep. It could be allocated
// once if allocation cost matters.
164  shared_ptr<Encapsulation<time>> rhs = this->get_factory()->create(pfasst::encap::solution);
165 
166  auto const anodes = augment(t, dt, this->quadrature->get_nodes());
167  for (size_t m = 0; m < anodes.size() - 1; ++m) {
168  auto const ds = anodes[m+1] - anodes[m];
169  rhs->copy(m == 0 ? this->get_start_state() : this->state[m-1]);
170  rhs->saxpy(1.0, this->s_integrals[m]);
// NOTE(review): the implicit diagonal term uses the node spacing ds rather than
// dt * q_tilde(m, m); the diagonal of q_tilde is not used anywhere in this sweep.
// Confirm this mix is intended.
171  rhs->saxpy(-ds, this->fs_impl[m]);
// fs_impl[n] for n < m was passed to impl_solve earlier in this loop.
172  for (size_t n = 0; n < m; n++) {
173  rhs->saxpy(dt*this->q_tilde(m, n), this->fs_impl[n]);
174  }
175  this->impl_solve(this->fs_impl[m], this->state[m], anodes[m], ds, rhs);
176  }
177  this->set_end_state();
178  }
179 
// NOTE(review): the declaration line (numbered 181 in this listing) is missing. It is
// presumably `void ImplicitSweeper<time>::advance()`. Confirm against the original source.
//
// Copies end_state into start_state, presumably so the next time step starts from the
// current step's end value.
180  template<typename time>
182  {
183  this->start_state->copy(this->end_state);
184  }
185 
186  template<typename time>
187  void ImplicitSweeper<time>::reevaluate(bool initial_only)
188  {
189  if (initial_only) {
190  return;
191  }
192  auto const dt = this->get_controller()->get_step_size();
193  auto const t0 = this->get_controller()->get_time();
194  auto const nodes = this->quadrature->get_nodes();
195  for (size_t m = 0; m < nodes.size(); m++) {
196  this->f_impl_eval(this->fs_impl[m], this->state[m], t0 + dt * nodes[m]);
197  }
198  }
199 
200  template<typename time>
201  void ImplicitSweeper<time>::integrate(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
202  {
203  dst[0]->mat_apply(dst, dt, this->quadrature->get_q_mat(), this->fs_impl, true);
204  }
205 
206  }
207 }
void setup(shared_ptr< WrapperInterface< scalar, time >> wrapper)
STL namespace.
Value exception.
Definition: interfaces.hpp:50
#define ML_CLOG(level, logger_id, x)
same as CLOG(level, logger, x) from easylogging++
Definition: logging.hpp:117
vector< time > augment(time t0, time dt, vector< time > const &nodes)
Augment nodes: returns [t0, t0 + dt * nodes[0], ..., t0 + dt * nodes[M-1]].
virtual void setup(bool coarse) override
Set up (allocate, etc.) the sweeper.
pair< Matrix< scalar >, Matrix< scalar > > lu_pair
Data/solution encapsulation.
tuple t
Definition: plot.py:12
#define UNUSED(expr)
Marks unused function parameters to suppress compiler warnings.
Definition: globals.hpp:32
Eigen::Matrix< scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > Matrix
float dt
Definition: plot.py:10
static lu_pair< scalar > lu_decomposition(const Matrix< scalar > &A)
LU (without pivoting) decomposition.