Coverage for pySDC/core/problem.py: 95%
41 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-18 08:18 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-18 08:18 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""
4Description
5-----------
7Module containing the base Problem class for pySDC
8"""
10import logging
12from pySDC.core.common import RegisterParams
class WorkCounter(object):
    """
    Simple tally for counting iterations or other units of work.

    The tally is kept in the attribute ``niter``, which is zero right after
    instantiation. Calling the instance like a function adds one to it, and
    ``decrement`` takes one away, e.g.

    >>> counter = WorkCounter()  # counter.niter is 0
    >>> counter()                # counter.niter is 1
    >>> counter()                # counter.niter is 2
    """

    def __init__(self):
        # running tally, starts at zero
        self.niter = 0

    def __call__(self, *args, **kwargs):
        # Any positional/keyword arguments are accepted and ignored; they are
        # needed so the instance can be used as a callback for gmres.
        self.niter = self.niter + 1

    def decrement(self):
        """Take one unit off the tally."""
        self.niter = self.niter - 1

    def __str__(self):
        # string form is just the current tally
        return str(self.niter)
class Problem(RegisterParams):
    """
    Prototype class for problems, just defines the attributes essential to get started.

    Parameters
    ----------
    init : list of args
        Argument(s) used to initialize data types.
    dtype_u : type
        Variable data type. Should generate a data variable using dtype_u(init).
    dtype_f : type
        RHS data type. Should generate a data variable using dtype_f(init).

    Attributes
    ----------
    logger: logging.Logger
        custom logger for problem-related logging.
    work_counters : dict
        Dictionary to store WorkCounter objects; empty after construction.
    """

    logger = logging.getLogger('problem')
    dtype_u = None
    dtype_f = None

    def __init__(self, init):
        self.work_counters = {}  # Dictionary to store WorkCounter objects
        self.init = init  # Initialization parameter to instantiate data types

    @property
    def u_init(self):
        """Generate a data variable for u"""
        return self.dtype_u(self.init)

    @property
    def f_init(self):
        """Generate a data variable for RHS"""
        return self.dtype_f(self.init)

    @classmethod
    def get_default_sweeper_class(cls):
        """
        Default sweeper class for this problem.

        The base class provides none; subclasses are expected to override this.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(f'No default sweeper class implemented for {cls} problem!')

    def setUpFieldsIO(self):
        """
        Set up FieldsIO for MPI with the space decomposition of this problem

        The base implementation does nothing.
        """
        pass

    def getOutputFile(self, fileName):
        """
        Hook for writing the solution to a file; not available in the base class.

        Parameters
        ----------
        fileName : str
            Name of the output file (its meaning is defined by subclasses).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(f'No output file implemented for {type(self).__name__}')

    def processSolutionForOutput(self, u):
        """
        Hook to transform a solution before it is output.

        The base implementation returns ``u`` unchanged.
        """
        return u

    def eval_f(self, u, t):
        """
        Abstract interface to RHS computation of the ODE

        Parameters
        ----------
        u : dtype_u
            Current values.
        t : float
            Current time.

        Returns
        -------
        f : dtype_f
            The RHS values.
        """
        raise NotImplementedError('ERROR: problem has to implement eval_f(self, u, t)')

    def apply_mass_matrix(self, u):  # pragma: no cover
        """Default mass matrix : identity"""
        return u

    def generate_scipy_reference_solution(self, eval_rhs, t, u_init=None, t_init=None, **kwargs):
        """
        Compute a reference solution using `scipy.solve_ivp` with very small tolerances.
        Keep in mind that scipy needs the solution to be a one dimensional array. If you are solving something higher
        dimensional, you need to make sure the function `eval_rhs` takes a flattened one-dimensional version as an input
        and output, but reshapes to whatever the problem needs for evaluation.

        The keyword arguments will be passed to `scipy.solve_ivp`. You should consider passing `method='BDF'` for stiff
        problems and to accelerate that you can pass a function that evaluates the Jacobian with arguments `jac(t, u)`
        as `jac=jac`.

        Args:
            eval_rhs (function): Function to evaluate the full right hand side. Must have signature `eval_rhs(float: t, numpy.1darray: u)`
            t (float): current time
            u_init (dtype_u): initial conditions for getting the exact solution. If `None`, they are taken from
                `self.u_exact(t=t_init)`, which the problem class must then provide.
            t_init (float): the starting time, defaults to 0
            **kwargs: passed on to `scipy.integrate.solve_ivp`; `atol` and `rtol` given here override the defaults

        Returns:
            numpy.ndarray: Reference solution, reshaped to the shape of the initial conditions

        If `solve_ivp` reports that it did not succeed, a warning is logged and the last computed state is returned.
        """
        import numpy as np
        from scipy.integrate import solve_ivp

        kwargs = {
            'atol': 100 * np.finfo(float).eps,
            'rtol': 100 * np.finfo(float).eps,
            **kwargs,
        }
        # Resolve the starting time first so that default initial conditions belong to it: taking u_exact(t=0)
        # while integrating from a non-zero t_init would start the integration from the wrong state.
        t_init = 0 if t_init is None else t_init
        # `u_init * 1.0` presumably avoids working on the caller's object (and promotes to floating point).
        u_init = self.u_exact(t=t_init) if u_init is None else u_init * 1.0

        u_shape = u_init.shape
        sol = solve_ivp(eval_rhs, (t_init, t), u_init.flatten(), **kwargs)
        if not sol.success:
            # do not hand out a failed integration silently as a "reference" solution
            self.logger.warning('scipy.integrate.solve_ivp did not succeed: %s', sol.message)
        return sol.y[:, -1].reshape(u_shape)

    def get_fig(self):
        """
        Get a figure suitable to plot the solution of this problem

        Returns
        -------
        self.fig : matplotlib.pyplot.figure.Figure
        """
        raise NotImplementedError

    def plot(self, u, t=None, fig=None):
        r"""
        Plot the solution. Please supply a figure with the same structure as returned by ``self.get_fig``.

        Parameters
        ----------
        u : dtype_u
            Solution to be plotted
        t : float
            Time to display at the top of the figure
        fig : matplotlib.pyplot.figure.Figure
            Figure with the correct structure

        Returns
        -------
        None
        """
        raise NotImplementedError