Coverage for pySDC/core/problem.py: 97%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 14:51 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Description
5-----------
7Module containing the base Problem class for pySDC
8"""
10import logging
12from pySDC.core.common import RegisterParams
class WorkCounter:
    """
    Callable tally of how often some piece of work has been performed.

    The running total lives in the single attribute `niter`, which starts at
    zero. Calling the instance bumps the total by one; any arguments passed
    along with the call are accepted and ignored.

    >>> count = WorkCounter()  # => niter = 0
    >>> count()  # => niter = 1
    >>> count()  # => niter = 2
    """

    def __init__(self):
        self.niter = 0

    def __call__(self, *args, **kwargs):
        # Arbitrary positional/keyword arguments must be tolerated so that the
        # counter can be handed to solvers (e.g. gmres) as a callback.
        self.niter += 1

    def decrement(self):
        """Take one count back off the total."""
        self.niter -= 1

    def __str__(self):
        return '{}'.format(self.niter)
class Problem(RegisterParams):
    """
    Prototype class for problems, just defines the attributes essential to get started.

    Parameters
    ----------
    init : list of args
        Argument(s) used to initialize data types.

    Attributes
    ----------
    logger : logging.Logger
        Custom logger for problem-related logging.
    dtype_u : type
        Variable data type. Should generate a data variable using dtype_u(init).
        ``None`` in this base class; concrete problems are expected to set it.
    dtype_f : type
        RHS data type. Should generate a data variable using dtype_f(init).
        ``None`` in this base class; concrete problems are expected to set it.
    work_counters : dict
        Dictionary to store WorkCounter objects (empty after construction).
    init : list of args
        The initialization parameter handed to the data types.
    """

    logger = logging.getLogger('problem')
    dtype_u = None
    dtype_f = None

    def __init__(self, init):
        self.work_counters = {}  # Dictionary to store WorkCounter objects
        self.init = init  # Initialization parameter to instantiate data types

    @property
    def u_init(self):
        """Generate a data variable for u"""
        return self.dtype_u(self.init)

    @property
    def f_init(self):
        """Generate a data variable for RHS"""
        return self.dtype_f(self.init)

    @classmethod
    def get_default_sweeper_class(cls):
        """
        Return the sweeper class that should be used for this problem by default.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses may override.
        """
        raise NotImplementedError(f'No default sweeper class implemented for {cls} problem!')

    def eval_f(self, u, t):
        """
        Abstract interface to RHS computation of the ODE

        Parameters
        ----------
        u : dtype_u
            Current values.
        t : float
            Current time.

        Returns
        -------
        f : dtype_f
            The RHS values.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses have to override.
        """
        raise NotImplementedError('ERROR: problem has to implement eval_f(self, u, t)')

    def apply_mass_matrix(self, u):  # pragma: no cover
        """Default mass matrix : identity"""
        return u

    def generate_scipy_reference_solution(self, eval_rhs, t, u_init=None, t_init=None, **kwargs):
        """
        Compute a reference solution using `scipy.solve_ivp` with very small tolerances.
        Keep in mind that scipy needs the solution to be a one dimensional array. If you are solving something higher
        dimensional, you need to make sure the function `eval_rhs` takes a flattened one-dimensional version as an input
        and output, but reshapes to whatever the problem needs for evaluation.

        The keyword arguments will be passed to `scipy.solve_ivp`. You should consider passing `method='BDF'` for stiff
        problems and to accelerate that you can pass a function that evaluates the Jacobian with arguments `jac(t, u)`
        as `jac=jac`.

        Parameters
        ----------
        eval_rhs : callable
            Function evaluating the full right hand side. Must have signature
            `eval_rhs(float: t, numpy.1darray: u)`.
        t : float
            Time at which the reference solution is requested.
        u_init : dtype_u, optional
            Initial conditions at time `t_init`. If omitted, `self.u_exact(t=t_init)` is used, so that the initial
            state always belongs to the time the integration starts from.
        t_init : float, optional
            The starting time of the integration. Defaults to 0.
        **kwargs
            Passed on to `scipy.solve_ivp`; values given here override the default `atol` and `rtol`.

        Returns
        -------
        numpy.ndarray
            Reference solution at time `t`, in the shape of the initial conditions.
        """
        import numpy as np
        from scipy.integrate import solve_ivp

        # Tight default tolerances; anything the caller passes explicitly wins.
        kwargs = {
            'atol': 100 * np.finfo(float).eps,
            'rtol': 100 * np.finfo(float).eps,
            **kwargs,
        }

        # Resolve the start time first: the default initial state has to be taken at this very time, otherwise the
        # integration would start from the state at t=0 while the clock starts at `t_init`.
        t_init = 0 if t_init is None else t_init
        # NOTE(review): `* 1.0` presumably yields a float-valued copy for array-like inputs, so the caller's data is
        # not handed to the integrator directly - confirm for custom data types.
        u_init = self.u_exact(t=t_init) if u_init is None else u_init * 1.0

        u_shape = u_init.shape
        # scipy integrates flat vectors only; the last stored column is the state at the final time `t`.
        return solve_ivp(eval_rhs, (t_init, t), u_init.flatten(), **kwargs).y[:, -1].reshape(u_shape)

    def get_fig(self):
        """
        Get a figure suitable to plot the solution of this problem

        Returns
        -------
        fig : matplotlib.pyplot.figure.Figure
            Figure with the structure expected by ``self.plot``.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses may override.
        """
        raise NotImplementedError

    def plot(self, u, t=None, fig=None):
        r"""
        Plot the solution. Please supply a figure with the same structure as returned by ``self.get_fig``.

        Parameters
        ----------
        u : dtype_u
            Solution to be plotted
        t : float
            Time to display at the top of the figure
        fig : matplotlib.pyplot.figure.Figure
            Figure with the correct structure

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses may override.
        """
        raise NotImplementedError