Coverage for pySDC/projects/DAE/misc/ProblemDAE.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1import numpy as np 

2from scipy.optimize import root 

3 

4from pySDC.core.Problem import ptype, WorkCounter 

5from pySDC.projects.DAE.misc.DAEMesh import DAEMesh 

6 

7 

class ptype_dae(ptype):
    r"""
    This class implements a generic DAE class and illustrates the interface class for DAE problems.
    It ensures that all parameters are passed that are needed by DAE sweepers.

    Parameters
    ----------
    nvars : int
        Number of unknowns of the problem class.
    newton_tol : float
        Tolerance for the nonlinear solver.
    method : str, optional
        Method passed to ``scipy.optimize.root`` in ``solve_system``. Defaults to ``'hybr'``, which
        was previously the only (hard-coded) choice, so existing subclasses keep their behavior.

    Attributes
    ----------
    work_counters : mapping of str to WorkCounter
        Counts the work. The number of function evaluations during the nonlinear solve is
        accumulated in ``work_counters['newton']``. A counter ``work_counters['rhs']`` is created
        for right-hand side evaluations; this class does not increment it itself (presumably
        done by subclasses when evaluating the right-hand side).
    """

    dtype_u = DAEMesh
    dtype_f = DAEMesh

    def __init__(self, nvars, newton_tol, method='hybr'):
        """Initialization routine"""
        super().__init__((nvars, None, np.dtype('float64')))
        self._makeAttributeAndRegister('nvars', 'newton_tol', 'method', localVars=locals(), readOnly=True)

        self.work_counters['newton'] = WorkCounter()
        self.work_counters['rhs'] = WorkCounter()

    def solve_system(self, impl_sys, u0, t):
        r"""
        Solver for nonlinear implicit system (defined in sweeper).

        The system is handed to ``scipy.optimize.root`` in flattened form and the solution is
        reshaped back to the problem's shape. The result is returned regardless of whether the
        root finder reports convergence (``opt.success`` is not checked here).

        Parameters
        ----------
        impl_sys : callable
            Implicit system to be solved. Called with the unknowns viewed as ``type(u0)``; its
            return value must support ``flatten()``.
        u0 : dtype_u
            Initial guess for solver.
        t : float
            Current time :math:`t`. Not used by this generic solver; kept for interface
            compatibility.

        Returns
        -------
        me : dtype_u
            Numerical solution.
        """
        me = self.dtype_u(self.init)
        n_calls = 0

        def implSysFlatten(unknowns, **kwargs):
            # root() works on flat 1D arrays, so restore the problem's shape/type for impl_sys
            # and flatten its residual again on the way out.
            nonlocal n_calls
            n_calls += 1
            sys = impl_sys(unknowns.reshape(me.shape).view(type(u0)), **kwargs)
            return sys.flatten()

        opt = root(
            implSysFlatten,
            u0,
            method=self.method,
            tol=self.newton_tol,
        )
        me[:] = opt.x.reshape(me.shape)

        # MINPACK-based methods such as 'hybr' report 'nfev'; the nonlin_solve-based methods
        # (e.g. 'broyden1', 'krylov') do not, so fall back to the counted calls there.
        self.work_counters['newton'].niter += opt.nfev if 'nfev' in opt else n_calls
        return me