Coverage for pySDC/projects/parallelSDC_reloaded/convergence.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-16 14:51 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Created on Tue Jan 9 14:44:41 2024 

5 

6Generate convergence plots on Dahlquist for SDC with given parameters 

7""" 

8import numpy as np 

9from pySDC.projects.parallelSDC_reloaded.utils import getParamsRK, getParamsSDC, solutionSDC, plt 

10 

# Base name of this script: last '/'-separated path component, cut at the first '.'.
SCRIPT = __file__.split('/')[-1].split('.')[0]

# Script parameters
lam = 1j  # value passed as the single entry of `lambdas` to the solver
tEnd = 2 * np.pi  # final time of the integration interval [0, tEnd]
# Numbers of time steps to test; each one yields one point of the convergence curve.
nStepsList = np.array([2, 5, 10, 20, 50, 100, 200, 500, 1000])
# Step sizes matching nStepsList entry by entry.
dtVals = tEnd / nStepsList

18 

19 

def getError(uNum, uRef):
    """
    Return the maximum-norm error between a numerical and a reference solution.

    Parameters
    ----------
    uNum : array-like or None
        Numerical solution; only its first column (``uNum[:, 0]``) is compared.
        ``None`` is treated as "no solution available".
    uRef : array-like
        Reference values, subtracted element-wise from ``uNum[:, 0]``.

    Returns
    -------
    float
        ``np.inf`` when ``uNum`` is ``None``, otherwise the infinity norm
        of ``uRef - uNum[:, 0]``.
    """
    if uNum is None:  # pragma: no cover
        return np.inf
    return np.linalg.norm(uRef - uNum[:, 0], np.inf)

24 

25 

# Collocation parameters (forwarded to getParamsSDC for the SDC schemes)
nNodes = 4
nodeType = "LEGENDRE"
quadType = "RADAU-RIGHT"
sweepType = "MIN-SR-NS"

# Schemes parameters
# Each entry is (scheme name, number of sweeps); a sweep count of None marks
# an entry that is configured through getParamsRK instead of getParamsSDC.
# The trailing [:1] keeps only the first SDC entry (K=1) of the K=1..4 list.
schemes = [("RK4", None), ("ESDIRK43", None), *[(sweepType, i) for i in [1, 2, 3, 4]][:1]]

# Plot styles, paired with `schemes` by position.
styles = [
    dict(ls=":", c="gray"),
    dict(ls="-.", c="gray"),
    dict(ls="-", marker='o'),
    dict(ls="-", marker='>'),
    dict(ls="-", marker='s'),
    dict(ls="-", marker='^'),
    dict(ls="-", marker='*'),
]

44 

# -----------------------------------------------------------------------------
# Script execution
# -----------------------------------------------------------------------------
# For every scheme, run the solver once per step count, record the max-norm
# error against exp(lam * t), and plot error versus step size on log-log axes.
plt.figure()
# zip() stops at the shorter sequence, so surplus entries of `styles` go unused.
for (qDelta, nSweeps), style in zip(schemes, styles):
    if nSweeps is None:
        # No sweep count: parameters come from getParamsRK and the curve gets no label.
        params = getParamsRK(qDelta)
        label = None
    else:
        params = getParamsSDC(quadType, nNodes, qDelta, nSweeps, nodeType)
        label = f"$K={nSweeps}$"
    errors = []

    for nSteps in nStepsList:
        uNum, counters, parallel = solutionSDC(tEnd, nSteps, params, 'DAHLQUIST', lambdas=np.array([lam]))

        # Reference values exp(lam * t) on the nSteps + 1 equispaced points of [0, tEnd].
        tVals = np.linspace(0, tEnd, nSteps + 1)
        uExact = np.exp(lam * tVals)

        err = getError(uNum, uExact)
        errors.append(err)

    plt.loglog(dtVals, errors, **style, label=label)
    if nSweeps is not None:
        # Dashed gray guide curve (0.1 * dt) ** K drawn next to the measured errors.
        plt.loglog(dtVals, (0.1 * dtVals) ** nSweeps, '--', c='gray', lw=1.5)

plt.title(sweepType)
plt.legend()
plt.xlabel(r"$\Delta{t}$")
plt.ylabel(r"$L_\infty$ error")
plt.grid(True)
plt.tight_layout()