PFASST++
encap_sweeper_impl.hpp
Go to the documentation of this file.
2 
3 #include <algorithm>
4 #include <cassert>
5 using namespace std;
6 
7 #include "pfasst/globals.hpp"
8 #include "pfasst/config.hpp"
9 
10 
11 namespace pfasst
12 {
13  namespace encap
14  {
// Constructor member-initialiser list: the sweeper starts without a quadrature
// (null pointer) and with both residual tolerances at zero.
// NOTE(review): the declarator line (listing line 16) is absent from this
// extraction -- presumably `EncapSweeper<time>::EncapSweeper()`; confirm
// against the class declaration before relying on it.
15  template<typename time>
17  : quadrature(nullptr)
18  , abs_residual_tol(0.0)
19  , rel_residual_tol(0.0)
20  {}
21 
22  template<typename time>
23  shared_ptr<Encapsulation<time>> EncapSweeper<time>::get_state(size_t m) const
24  {
25  return this->state[m];
26  }
27 
28  template<typename time>
29  shared_ptr<Encapsulation<time>> EncapSweeper<time>::get_tau(size_t m) const
30  {
31  return this->fas_corrections[m];
32  }
33 
34  template<typename time>
35  shared_ptr<Encapsulation<time>> EncapSweeper<time>::get_saved_state(size_t m) const
36  {
37  return this->saved_state[m];
38  }
39 
52  template<typename time>
54  {
55  this->abs_residual_tol = time(config::get_value<double>("abs_res_tol", this->abs_residual_tol));
56  this->rel_residual_tol = time(config::get_value<double>("rel_res_tol", this->rel_residual_tol));
57  }
58 
67  template<typename time>
68  void EncapSweeper<time>::setup(bool coarse)
69  {
70  auto const nodes = this->quadrature->get_nodes();
71  auto const num_nodes = this->quadrature->get_num_nodes();
72 
73  this->start_state = this->get_factory()->create(pfasst::encap::solution);
74  this->end_state = this->get_factory()->create(pfasst::encap::solution);
75 
76  for (size_t m = 0; m < num_nodes; m++) {
77  this->state.push_back(this->get_factory()->create(pfasst::encap::solution));
78  if (coarse) {
79  this->saved_state.push_back(this->get_factory()->create(pfasst::encap::solution));
80  }
81  }
82 
83  if (coarse) {
84  for (size_t m = 0; m < num_nodes; m++) {
85  this->fas_corrections.push_back(this->get_factory()->create(pfasst::encap::solution));
86  }
87  }
88  }
89 
94  template<typename time>
96  {
97  for (size_t m = 1; m < this->quadrature->get_num_nodes(); m++) {
98  this->state[m]->copy(this->state[0]);
99  }
100  }
101 
102  template<typename time>
103  void EncapSweeper<time>::save(bool initial_only)
104  {
105  // XXX: if !left_is_node, this is a problem...
106  if (initial_only) {
107  this->saved_state[0]->copy(state[0]);
108  } else {
109  for (size_t m = 0; m < this->saved_state.size(); m++) {
110  this->saved_state[m]->copy(state[m]);
111  }
112  }
113  }
114 
115  template<typename time>
117  {
118  this->quadrature = quadrature;
119  }
120 
121  template<typename time>
122  shared_ptr<const IQuadrature<time>> EncapSweeper<time>::get_quadrature() const
123  {
124  return this->quadrature;
125  }
126 
127  template<typename time>
128  shared_ptr<Encapsulation<time>> EncapSweeper<time>::get_start_state() const
129  {
130  return this->start_state;
131  }
132 
133  template<typename time>
134  const vector<time> EncapSweeper<time>::get_nodes() const
135  {
136  return this->quadrature->get_nodes();
137  }
138 
139  template<typename time>
141  {
142  this->factory = factory;
143  }
144 
145  template<typename time>
146  shared_ptr<EncapFactory<time>> EncapSweeper<time>::get_factory() const
147  {
148  return factory;
149  }
150 
151  template<typename time>
152  shared_ptr<Encapsulation<time>> EncapSweeper<time>::get_end_state()
153  {
154  return this->end_state;
155  }
156 
160  template<typename time>
162  {
163  throw NotImplementedYet("sweeper");
164  }
165 
169  template<typename time>
170  void EncapSweeper<time>::reevaluate(bool initial_only)
171  {
172  UNUSED(initial_only);
173  throw NotImplementedYet("sweeper");
174  }
175 
179  template<typename time>
180  void EncapSweeper<time>::integrate(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
181  {
182  UNUSED(dt); UNUSED(dst);
183  throw NotImplementedYet("sweeper");
184  }
185 
186  template<typename time>
187  void EncapSweeper<time>::set_residual_tolerances(time abs_residual_tol, time rel_residual_tol,
188  int order)
189  {
190  this->abs_residual_tol = abs_residual_tol;
191  this->rel_residual_tol = rel_residual_tol;
192  this->residual_norm_order = order;
193  }
194 
198  template<typename time>
199  void EncapSweeper<time>::residual(time dt, vector<shared_ptr<Encapsulation<time>>> dst) const
200  {
201  UNUSED(dt); UNUSED(dst);
202  throw NotImplementedYet("residual");
203  }
204 
217  template<typename time>
219  {
220  if (this->abs_residual_tol > 0.0 || this->rel_residual_tol > 0.0) {
221  if (this->residuals.size() == 0) {
222  for (size_t m = 0; m < this->get_nodes().size(); m++) {
223  this->residuals.push_back(this->get_factory()->create(pfasst::encap::solution));
224  }
225  }
226  this->residual(this->get_controller()->get_step_size(), this->residuals);
227  vector<time> anorms, rnorms;
228  for (size_t m = 0; m < this->get_nodes().size(); m++) {
229  anorms.push_back(this->residuals[m]->norm0());
230  rnorms.push_back(anorms.back() / this->get_state(m)->norm0());
231  }
232  auto amax = *std::max_element(anorms.begin(), anorms.end());
233  auto rmax = *std::max_element(rnorms.begin(), rnorms.end());
234  if (amax < this->abs_residual_tol || rmax < this->rel_residual_tol) {
235  return true;
236  }
237  }
238  return false;
239  }
240 
241  template<typename time>
243  {
244  this->start_state->post(comm, tag);
245  }
246 
247  template<typename time>
248  void EncapSweeper<time>::send(ICommunicator* comm, int tag, bool blocking)
249  {
250  this->end_state->send(comm, tag, blocking);
251  }
252 
253  template<typename time>
254  void EncapSweeper<time>::recv(ICommunicator* comm, int tag, bool blocking)
255  {
256  this->start_state->recv(comm, tag, blocking);
257  if (this->quadrature->left_is_node()) {
258  this->state[0]->copy(this->start_state);
259  }
260  }
261 
262  template<typename time>
264  {
265  if (comm->rank() == comm->size() - 1) {
266  this->start_state->copy(this->end_state);
267  }
268  this->start_state->broadcast(comm);
269  }
270 
271 
272  template<typename time>
274  {
275  shared_ptr<EncapSweeper<time>> y = dynamic_pointer_cast<EncapSweeper<time>>(x);
276  assert(y);
277  return *y.get();
278  }
279 
// Const variant of the sweeper cast helper: dynamic-casts `x` to a
// pointer-to-const EncapSweeper and returns a reference to the pointee.
// NOTE(review): the declarator line (listing line 281) is absent from this
// extraction -- presumably
// `const EncapSweeper<time>& as_encap_sweeper(shared_ptr<const ISweeper<time>> x)`;
// confirm against the declaration.
280  template<typename time>
282  {
283  shared_ptr<const EncapSweeper<time>> y = dynamic_pointer_cast<const EncapSweeper<time>>(x);
// The cast result is checked by assert only; with assertions disabled a failed
// cast leads to dereferencing a null pointer on the next line.
284  assert(y);
285  return *y.get();
286  }
287 
288  } // ::pfasst::encap
289 } // ::pfasst
virtual bool converged() override
Return convergence status.
virtual void recv(ICommunicator *comm, int tag, bool blocking) override
virtual void send(ICommunicator *comm, int tag, bool blocking) override
virtual shared_ptr< Encapsulation< time > > get_end_state()
Host based encapsulated base sweeper.
Interface for quadrature handlers.
Definition: interface.hpp:232
Not implemented yet exception.
Definition: interfaces.hpp:29
STL namespace.
virtual shared_ptr< Encapsulation< time > > get_state(size_t m) const
Retrieve solution values of current iteration at time node index m.
void set_residual_tolerances(time abs_residual_tol, time rel_residual_tol, int order=0)
Set residual tolerances for convergence checking.
virtual void integrate(time dt, vector< shared_ptr< Encapsulation< time >>> dst) const
Integrates values of right hand side at all time nodes \( t \in [0,M-1] \) simultaneously.
virtual shared_ptr< const IQuadrature< time > > get_quadrature() const
virtual void spread() override
Initialize solution values at all time nodes with meaningful values.
EncapSweeper< time > & as_encap_sweeper(shared_ptr< ISweeper< time >> x)
virtual int size()=0
Abstract interface of factory for creating Encapsulation objects.
virtual void save(bool initial_only) override
Save states (and/or function values) at all nodes.
virtual void setup(bool coarse) override
Setup (allocate etc) the sweeper.
virtual void set_factory(shared_ptr< EncapFactory< time >> factory)
virtual void reevaluate(bool initial_only=false)
Re-evaluate function values.
virtual void residual(time dt, vector< shared_ptr< Encapsulation< time >>> dst) const
Compute residual at each SDC node (including FAS corrections).
virtual shared_ptr< Encapsulation< time > > get_saved_state(size_t m) const
Retrieve solution values of previous iteration at time node index m.
virtual int rank()=0
virtual void set_options() override
Set options from command line etc.
Data/solution encapsulation.
virtual void broadcast(ICommunicator *comm) override
#define UNUSED(expr)
Denoting unused function parameters for omitting compiler warnings.
Definition: globals.hpp:32
virtual const vector< time > get_nodes() const
virtual void advance() override
Advance from one time step to the next.
Abstract interface for communicators.
Definition: interfaces.hpp:70
virtual shared_ptr< Encapsulation< time > > get_start_state() const
virtual shared_ptr< Encapsulation< time > > get_tau(size_t m) const
Retrieve FAS correction of current iteration at time node index m.
virtual shared_ptr< EncapFactory< time > > get_factory() const
Abstract SDC sweeper.
Definition: interfaces.hpp:164
virtual void post(ICommunicator *comm, int tag) override
float dt
Definition: plot.py:10
static precision norm0(const vector< precision > &data)
virtual void set_quadrature(shared_ptr< IQuadrature< time >> quadrature)