# coding=utf-8
"""
.. moduleauthor:: Torbjörn Klatt <[email protected]>
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import is_interactive
from pypint.plugins.plotters.i_plotter import IPlotter
from pypint.plugins.plotters import colorline
from pypint.solvers.i_iterative_time_solver import IIterativeTimeSolver
from pypint.solvers.states import ISolverState
from pypint.solvers.diagnosis.norms import supremum_norm
from pypint.utilities import assert_named_argument
from pypint.utilities.logging import LOG
class SingleSolutionPlotter(IPlotter):
    """Plotter for a single solution of an iterative time solver.

    See Also
    --------
    :py:class:`.IPlotter` : overridden class
    """

    def __init__(self, *args, **kwargs):
        # positional arguments are forwarded unpacked (not as a single tuple argument)
        super(SingleSolutionPlotter, self).__init__(*args, **kwargs)
        self._solver = None
        self._state = None
        self._nodes = None
        self._errplot = False
        self._residualplot = False

    def plot(self, *args, **kwargs):
        """Plots the solution and optionally also the error and residual for each iteration.

        Parameters
        ----------
        solver : :py:class:`.IIterativeTimeSolver`
            solver instance used to calculate the solution
        state : :py:class:`.ISolverState`
            state containing information to plot
        errorplot : :py:class:`bool`
            *(optional)*
            if given and :py:class:`True` also plots the errors for each iteration found in the solution;
            ``errplot`` is accepted as an alias
        residualplot : :py:class:`bool`
            *(optional)*
            if given and :py:class:`True` also plots the residual for each iteration found in the solution

        Raises
        ------
        ValueError
            * if ``solver`` is not given or not an :py:class:`.IIterativeTimeSolver`
            * if ``state`` is not given or not an :py:class:`.ISolverState`
        """
        super(SingleSolutionPlotter, self).plot(*args, **kwargs)
        assert_named_argument('solver', kwargs, types=IIterativeTimeSolver, descriptor="Solver", checking_obj=self)
        assert_named_argument('state', kwargs, types=ISolverState, descriptor="State must be given", checking_obj=self)
        self._solver = kwargs['solver']
        self._state = kwargs['state']

        # Re-evaluate the options on every call so flags from a previous call do not leak into this one.
        self._errplot = bool(kwargs.get('errorplot', kwargs.get('errplot', False)))
        self._residualplot = bool(kwargs.get('residualplot', False))

        self._nodes = self._extended_nodes()

        # one row for the solution plus one for each optional diagnostic plot
        _num_rows = 1 + int(self._errplot) + int(self._residualplot)
        _row = 1
        if _num_rows > 1:
            plt.suptitle(r"after {:d} iterations; overall reduction: {:.2e}"
                         .format(len(self._state),
                                 supremum_norm(self._state.solution
                                               .solution_reduction(self._state.last_iteration_index))))
            plt.subplot(_num_rows, 1, _row)
        self._final_solution()
        plt.title(str(self._solver.problem))

        if self._errplot:
            _row += 1
            plt.subplot(_num_rows, 1, _row)
            self._error_plot()
        if self._residualplot:
            _row += 1
            plt.subplot(_num_rows, 1, _row)
            self._residual_plot()

        if self._file_name is not None:
            self._save_figure()

        if is_interactive():
            plt.show(block=True)
        else:
            plt.close('all')

    def _extended_nodes(self):
        """Return the state's time points, padded with the problem's start/end time where missing."""
        _nodes = self._state.first.time_points
        _problem = self._solver.problem
        if _problem.time_start != _nodes[0]:
            _nodes = np.concatenate(([_problem.time_start], _nodes))
        if _problem.time_end != _nodes[-1]:
            _nodes = np.concatenate((_nodes, [_problem.time_end]))
        return _nodes

    def _save_figure(self):
        """Resize the current figure to 15x15 inches at 300 DPI and save it to ``self._file_name``."""
        fig = plt.gcf()
        fig.set_dpi(300)
        fig.set_size_inches((15., 15.))
        _width, _height = fig.get_size_inches()
        # '{:g}' accepts both int and float; the former '{:s}'/'{:d}' specs raised on an ndarray / float DPI
        LOG.debug("Plotting figure with size (w,h) ({:g}, {:g}) inches and {:g} DPI."
                  .format(_width, _height, fig.get_dpi()))
        fig.savefig(self._file_name)

    @staticmethod
    def _is_complex(numeric_type):
        """Tell whether ``numeric_type`` denotes a complex scalar type (e.g. ``complex`` or ``numpy.complex128``)."""
        try:
            return np.issubdtype(numeric_type, np.complexfloating)
        except TypeError:
            # not interpretable as a numpy dtype -> treat as non-complex
            return False

    def _final_solution(self):
        """Plot the last iteration's solution with the initial value prepended.

        Complex solutions are drawn in the complex plane via ``colorline``; real solutions are drawn over the
        integration nodes.
        """
        _solution = np.insert(self._state.last_iteration.solution.values, 0, [self._state.initial.solution.value],
                              axis=0)
        _complex = self._is_complex(self._state.last_iteration.solution.numeric_type)
        if _complex:
            colorline(np.array([_p.real for _p in _solution[:, 0]], dtype=float),
                      np.array([_p.imag for _p in _solution[:, 0]], dtype=float))
        else:
            plt.plot(self._nodes, _solution[:, 0], label="Solution")
        # TODO fix error checking (plotting of the exact solution, if the problem provides one)
        # if problem_has_exact_solution(self._solver.problem, self):
        #         and self._state.last_iteration.solution.errors.max() > 1e-2:
        #     exact = np.array([self._solver.problem.exact(_t) for _t in self._nodes],
        #                      dtype=self._solver.problem.numeric_type)
        #     if exact.dtype == complex:
        #         LOG.debug("    plotting exact solution as complex values")
        #         colorline(exact.real, exact.imag)
        #     else:
        #         LOG.debug("    plotting exact solution as real values")
        #         plt.plot(self._nodes, exact, label="Exact")
        if _complex:
            plt.xlabel("real")
            plt.ylabel("imag")
        else:
            plt.xticks(self._nodes)
            plt.xlabel("integration nodes")
            plt.ylabel(r'$u(t, \phi_t)$')
            plt.xlim(self._nodes[0], self._nodes[-1])
            plt.legend()
        plt.grid(True)

    def _per_iteration_plot(self, extract, ylabel):
        """Plot one curve per iteration on a log scale.

        ``extract`` maps an iteration's solution to an iterable of objects with a ``.value``; a zero is
        prepended for the initial node.
        """
        for i in range(len(self._state)):
            _values = np.insert(np.array([_v.value for _v in extract(self._state[i].solution)]), 0, [0.0], axis=0)
            plt.plot(self._nodes, _values, label=r"Iteration {:d}".format(i + 1))
        plt.xticks(self._nodes)
        plt.xlim(self._nodes[0], self._nodes[-1])
        plt.yscale("log")
        plt.xlabel("integration nodes")
        plt.ylabel(ylabel)
        plt.grid(True)

    def _error_plot(self):
        """Plot the absolute error of every iteration."""
        self._per_iteration_plot(lambda _sol: _sol.errors, r'absolute error of iterations')

    def _residual_plot(self):
        """Plot the residual of every iteration."""
        self._per_iteration_plot(lambda _sol: _sol.residuals, r'residual')