# Source code for pypint.solutions.iterative_solution

# coding=utf-8
"""

.. moduleauthor:: Torbjörn Klatt <[email protected]>
"""
import warnings
import copy

import numpy as np

from pypint.solutions.i_solution import ISolution
from pypint.solutions.data_storage.step_solution_data import StepSolutionData
from pypint.solutions.data_storage.trajectory_solution_data import TrajectorySolutionData
from pypint.utilities import assert_is_instance, assert_condition


class IterativeSolution(ISolution):
    """Storage for the solutions of an iterative solver.

    A new solution of a specific iteration can be added via :py:meth:`.add_solution` and queried via
    :py:meth:`.solution`.

    Examples
    --------
    By default, the internal solution data storage type is :py:class:`.TrajectorySolutionData` allowing for storage of
    the whole trajectory of the solution over the course of several iterations.
    However, this can be changed on initialization to only store the solution of the last time point over the course
    of iterations:

    >>> from pypint.solutions.iterative_solution import IterativeSolution
    >>> from pypint.solutions.data_storage.step_solution_data import StepSolutionData
    >>> my_reduced_full_solution = IterativeSolution(solution_data_type=StepSolutionData)
    """

    def __init__(self, *args, **kwargs):
        """
        Parameters
        ----------
        solution_data_type : :py:class:`.TrajectorySolutionData` or :py:class:`.StepSolutionData`
            Type of the per-iteration solution data storage objects accepted by :py:meth:`.add_solution`.
            Defaults to :py:class:`.TrajectorySolutionData`.

        Notes
        -----
        All positional and keyword arguments (including ``solution_data_type``) are also forwarded to the base
        class initializer.
        """
        super(IterativeSolution, self).__init__(*args, **kwargs)
        # one solution data storage object per stored iteration, in iteration order
        self._data = []
        self._data_type = kwargs.get('solution_data_type', TrajectorySolutionData)
        # reductions keyed by 1-based iteration number (see `set_error_reduction` / `set_solution_reduction`)
        self._error_reduction = {}
        self._solution_reduction = {}

    def add_solution(self, *args, **kwargs):
        """Adds a new solution data storage object.

        After each call an internal consistency check is carried out (see :py:meth:`._check_consistency`).
        If that check fails with a :py:class:`ValueError`, a warning is issued and the stored solutions are left
        unchanged; any other exception raised by the check is propagated after the stored solutions have been
        restored.
        The number of used iterations (see :py:attr:`.used_iterations`) is auto-incremented only on success.

        Parameters
        ----------
        data : :py:class:`.TrajectorySolutionData` or :py:class:`.StepSolutionData`
            Exactly one solution data object must be given, either as the single positional argument or as the
            ``data`` keyword argument. It must be an instance of the configured solution data type.

        iteration : :py:class:`int`
            *(optional)*
            1-based index of the iteration, i.e. the position at which the data object is inserted.
            Valid values are ``1`` up to ``len(solutions) + 1`` (the latter appends).
            Defaults to ``-1`` meaning append after last stored iteration.

        Raises
        ------
        ValueError :

            * if ``iteration`` is not an integer
            * if ``iteration`` is not a valid index for the current size of stored solution data objects
            * if not exactly one solution data object is given
        """
        if 'iteration' in kwargs:
            _position = self._insertion_position(kwargs.pop('iteration'))
        else:
            _position = None  # append after last stored iteration

        # the data object may be given positionally or as `data=`, but exactly once in total
        _candidates = args + ((kwargs.pop('data'),) if 'data' in kwargs else ())
        assert_condition(len(_candidates) == 1, ValueError,
                         message="Exactly one solution data object or 'data' must be given.",
                         checking_obj=self)
        _new_data = _candidates[0]
        assert_is_instance(_new_data, self._data_type, descriptor="Solution Data Storage", checking_obj=self)

        _old_data = list(self._data)  # backup for potential rollback
        if _position is None:
            self._data.append(_new_data)
        else:
            self._data.insert(_position, _new_data)

        try:
            self._check_consistency()
        except ValueError:
            # consistency check failed, thus removing recently added solution data storage
            warnings.warn("Consistency Check failed. Not adding this solution.")
            # restore in place so that references obtained via `solutions` stay valid
            self._data[:] = _old_data
            return
        except Exception:
            # unexpected failure: do not keep a possibly inconsistent state around
            self._data[:] = _old_data
            raise

        # NOTE(review): `_used_iterations` is not initialized in this class; presumably set up by `ISolution`
        self._used_iterations += 1

    def solution(self, iteration):
        """Accessor for the solution of a specific iteration.

        Parameters
        ----------
        iteration : :py:class:`int`
            0-based index of the iteration.
            ``-1`` means last iteration.

        Returns
        -------
        solution : instance of :py:attr:`.data_storage_type`
            or :py:class:`None` if no solutions are stored.

        Raises
        ------
        ValueError :
            If given ``iteration`` index is not in the valid range.
        """
        if not self._data:
            return None
        self._assert_valid_iteration(iteration)
        return self._data[iteration]

    def error(self, iteration):
        """Accessor for the errors of a specific iteration.

        Parameters
        ----------
        iteration : :py:class:`int`
            0-based index of the iteration.
            ``-1`` means last iteration.

        Returns
        -------
        error : :py:class:`numpy.ndarray` of objects or the ``errors`` of the stored data object
            For :py:class:`.StepSolutionData` storage the step's ``error`` is wrapped into an object array;
            otherwise the stored object's ``errors`` attribute is returned as-is.
            :py:class:`None` if no solutions are stored.

        Raises
        ------
        ValueError :
            If given ``iteration`` index is not in the valid range.
        """
        if not self._data:
            return None
        self._assert_valid_iteration(iteration)
        if self._data_type == StepSolutionData:
            # `np.object` was removed in NumPy 1.24; the builtin `object` is the supported spelling
            return np.array(self._data[iteration].error, dtype=object)
        return self._data[iteration].errors

    def residual(self, iteration):
        """Accessor for the residuals of a specific iteration.

        Parameters
        ----------
        iteration : :py:class:`int`
            0-based index of the iteration.
            ``-1`` means last iteration.

        Returns
        -------
        residual : :py:class:`numpy.ndarray` of objects or the ``residuals`` of the stored data object
            For :py:class:`.StepSolutionData` storage the step's ``residual`` is wrapped into an object array;
            otherwise the stored object's ``residuals`` attribute is returned as-is.
            :py:class:`None` if no solutions are stored.

        Raises
        ------
        ValueError :
            If given ``iteration`` index is not in the valid range.
        """
        if not self._data:
            return None
        self._assert_valid_iteration(iteration)
        if self._data_type == StepSolutionData:
            return np.array(self._data[iteration].residual, dtype=object)
        return self._data[iteration].residuals

    def error_reduction(self, iteration):
        """Accessor for the error reduction stored for ``iteration``.

        Returns
        -------
        reduction : :py:class:`float` or :py:class:`numpy.ndarray` or :py:class:`None`
            :py:class:`None` if no error reduction has been stored for ``iteration``.
        """
        return self._error_reduction.get(iteration)

    def solution_reduction(self, iteration):
        """Accessor for the solution reduction stored for ``iteration``.

        Returns
        -------
        reduction : :py:class:`float` or :py:class:`numpy.ndarray` or :py:class:`None`
            :py:class:`None` if no solution reduction has been stored for ``iteration``.
        """
        return self._solution_reduction.get(iteration)

    def set_error_reduction(self, iteration, reduction):
        """Stores a (shallow) copy of the error ``reduction`` of the given 1-based ``iteration``.

        Raises
        ------
        ValueError :
            If ``iteration`` is not a positive integer or ``reduction`` is neither a :py:class:`float` nor a
            :py:class:`numpy.ndarray`.
        """
        self._store_reduction(self._error_reduction, iteration, reduction, "Reduction of Error")

    def set_solution_reduction(self, iteration, reduction):
        """Stores a (shallow) copy of the solution ``reduction`` of the given 1-based ``iteration``.

        Raises
        ------
        ValueError :
            If ``iteration`` is not a positive integer or ``reduction`` is neither a :py:class:`float` nor a
            :py:class:`numpy.ndarray`.
        """
        self._store_reduction(self._solution_reduction, iteration, reduction, "Reduction of Solution")

    @property
    def solutions(self):
        """Read-only accessor for the stored list of solution data storages.

        Returns
        -------
        values : :py:class:`list` of :py:class:`.TrajectorySolutionData` or :py:class:`.StepSolutionData` objects
        """
        return self._data

    @property
    def time_points(self):
        """Proxies :py:attr:`.TrajectorySolutionData.time_points`.

        Returns
        -------
        time_points : :py:class:`numpy.ndarray` or :py:class:`None`
            For :py:class:`.TrajectorySolutionData` storage the time points of the first stored iteration; otherwise
            the ``time_point`` of every stored step, one per iteration.
            :py:class:`None` is returned if no solutions have yet been stored
        """
        if not self._data:
            return None
        if self._data_type == TrajectorySolutionData:
            return self._data[0].time_points
        # `np.float` was removed in NumPy 1.24; the builtin `float` maps to float64
        return np.array([step.time_point for step in self._data], dtype=float)

    def _check_consistency(self):
        """Check consistency of stored solution data objects.

        Raises
        ------
        ValueError :
            If the time points of at least two solution data storage objects differ.
        """
        if not self._data:
            return
        # NOTE(review): relies on every stored object exposing `time_points`; `time_points` above reads
        # `time_point` for StepSolutionData -- confirm StepSolutionData also provides `time_points`.
        _reference = self._data[0].time_points
        for _other in self._data[1:]:
            assert_condition(np.array_equal(_reference, _other.time_points), ValueError,
                             message="Time points of one or more stored solution data objects do not match.",
                             checking_obj=self)

    def _assert_valid_iteration(self, iteration):
        """Raises :py:class:`ValueError` unless ``iteration`` is ``-1`` or a 0-based index of a stored solution."""
        # `{}` instead of `{:d}`: the message is built eagerly and must not fail for non-int input
        assert_condition(iteration in range(-1, len(self._data)), ValueError,
                         message="Iteration index not within valid range: {} not in [-1, {:d}]"
                                 .format(iteration, len(self._data) - 1),
                         checking_obj=self)

    def _insertion_position(self, iteration):
        """Translates a 1-based ``iteration`` into a 0-based list insertion position.

        Returns :py:class:`None` for ``-1``, meaning "append after the last stored iteration".
        """
        assert_is_instance(iteration, int, descriptor="Iteration Index", checking_obj=self)
        if iteration == -1:
            return None
        # inserting at `len + 1` (1-based) is a plain append and therefore valid
        assert_condition(1 <= iteration <= len(self._data) + 1, ValueError,
                         message=("Iteration index must be within the size of the solution data array: "
                                  "{} not in [1, {:d}]".format(iteration, len(self._data) + 1)),
                         checking_obj=self)
        return iteration - 1

    def _store_reduction(self, storage, iteration, reduction, descriptor):
        """Validates and stores a copy of ``reduction`` in ``storage`` under the 1-based ``iteration``."""
        assert_condition(isinstance(iteration, int) and iteration > 0, ValueError,
                         message="Iteration must be a non-zero positive integer: NOT {}".format(iteration),
                         checking_obj=self)
        assert_is_instance(reduction, (float, np.ndarray), descriptor=descriptor, checking_obj=self)
        # copy so later in-place changes of the caller's array do not alter the stored value
        storage[iteration] = copy.copy(reduction)
# Public API of this module: only the iterative solution storage class is exported.
__all__ = ["IterativeSolution"]