Coverage for pySDC/projects/DAE/misc/ProblemDAE.py: 100%
21 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-29 09:02 +0000
1import numpy as np
2from scipy.optimize import root
4from pySDC.core.Problem import ptype, WorkCounter
5from pySDC.projects.DAE.misc.DAEMesh import DAEMesh
class ptype_dae(ptype):
    r"""
    Generic base class for DAE problems, defining the interface expected by the DAE sweepers.

    It makes sure every parameter a DAE sweeper relies on is passed in and registered.

    Parameters
    ----------
    nvars : int
        Number of unknowns of the problem class.
    newton_tol : float
        Tolerance handed to the nonlinear solver.

    Attributes
    ----------
    work_counters : WorkCounter
        Work bookkeeping. ``work_counters['newton']`` accumulates the number of residual
        evaluations performed by the nonlinear solver, and ``work_counters['rhs']`` is reserved
        for counting calls of the right-hand side.
    """

    dtype_u = DAEMesh
    dtype_f = DAEMesh

    def __init__(self, nvars, newton_tol):
        """Initialization routine"""
        # Mesh initialisation tuple: (size, communicator, dtype) -- data is always float64 here.
        super().__init__((nvars, None, np.dtype('float64')))
        # NOTE: locals() is captured here, before any other local is bound, so only the
        # constructor arguments (and self) are visible to the registration helper.
        self._makeAttributeAndRegister('nvars', 'newton_tol', localVars=locals(), readOnly=True)

        # One independent counter per tracked kind of work, registered in this order.
        for counter_name in ('newton', 'rhs'):
            self.work_counters[counter_name] = WorkCounter()

    def solve_system(self, impl_sys, u0, t):
        r"""
        Solve the nonlinear implicit system provided by the sweeper.

        The system is solved with ``scipy.optimize.root`` (Powell hybrid method, ``'hybr'``)
        using ``self.newton_tol`` as tolerance. The number of residual evaluations reported
        by the solver is added to ``work_counters['newton'].niter``.

        Parameters
        ----------
        impl_sys : callable
            Implicit system (residual function) whose root is sought.
        u0 : dtype_u
            Initial guess for the solver.
        t : float
            Current time :math:`t` (not used by this implementation).

        Returns
        -------
        me : dtype_u
            Numerical solution.
        """
        result = self.dtype_u(self.init)
        # scipy's root works on flat 1-D vectors, so remember how to map back to the mesh layout.
        mesh_shape = result.shape
        mesh_type = type(u0)

        def _flat_residual(flat_unknowns, **kwargs):
            # Rebuild a mesh-shaped view of the flat iterate, evaluate, and flatten the residual.
            unknowns = flat_unknowns.reshape(mesh_shape).view(mesh_type)
            return impl_sys(unknowns, **kwargs).flatten()

        solution = root(_flat_residual, u0, method='hybr', tol=self.newton_tol)

        result[:] = solution.x.reshape(mesh_shape)
        self.work_counters['newton'].niter += solution.nfev
        return result