Coverage for pySDC / projects / parallelSDC_reloaded / convergence.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-03 10:35 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Created on Tue Jan 9 14:44:41 2024 

5 

Generate convergence plots for SDC on the Dahlquist test equation, using the parameters set below

7""" 

8 

import os

import numpy as np

from pySDC.projects.parallelSDC_reloaded.utils import getParamsRK, getParamsSDC, solutionSDC, plt

11 

# Base name of this script without directory or extension (e.g. "convergence").
# os.path.basename honours the platform's path separators, so this also works
# when __file__ is a Windows-style path (splitting on '/' alone does not).
SCRIPT = os.path.basename(__file__).split('.')[0]

# Script parameters
lam = 1j  # Dahlquist eigenvalue (purely imaginary here)
tEnd = 2 * np.pi  # final simulation time
nStepsList = np.array([2, 5, 10, 20, 50, 100, 200, 500, 1000])  # numbers of time steps tested
dtVals = tEnd / nStepsList  # corresponding uniform time-step sizes

19 

20 

def getError(uNum, uRef):
    """
    Return the maximum-norm error of a numerical solution against reference values.

    Parameters
    ----------
    uNum : np.ndarray or None
        Numerical solution. Only its first column ``uNum[:, 0]`` is compared.
        When it is None, no comparison is made.
    uRef : np.ndarray
        Reference values, compared entrywise with ``uNum[:, 0]``.

    Returns
    -------
    float
        ``np.linalg.norm(uRef - uNum[:, 0], ord=np.inf)``, or ``np.inf`` when
        ``uNum`` is None (presumably a failed or unavailable solve; confirm
        against solutionSDC).
    """
    if uNum is not None:
        firstComponent = uNum[:, 0]
        return np.linalg.norm(uRef - firstComponent, ord=np.inf)
    return np.inf  # pragma: no cover

25 

26 

# Collocation parameters
nNodes = 4  # number of collocation nodes passed to getParamsSDC
nodeType = "LEGENDRE"
quadType = "RADAU-RIGHT"
sweepType = "MIN-SR-NS"  # QDelta used for every SDC scheme below

# Schemes parameters: (QDelta or RK name, number of sweeps); RK baselines carry None
schemes = [("RK4", None), ("ESDIRK43", None)]
# NOTE(review): the trailing [:1] keeps only the K=1 SDC variant; presumably
# deliberate (e.g. to limit run time). Drop it to compare every sweep count listed.
schemes.extend([(sweepType, k) for k in (1, 2, 3, 4)][:1])

# One matplotlib line style per scheme: gray for the two RK baselines,
# solid lines with distinct markers for the SDC variants.
styles = [dict(ls=":", c="gray"), dict(ls="-.", c="gray")]
styles.extend(dict(ls="-", marker=m) for m in "o>s^*")

45 

# -----------------------------------------------------------------------------
# Script execution
# -----------------------------------------------------------------------------
plt.figure()
for (qDelta, nSweeps), style in zip(schemes, styles, strict=False):
    # SDC schemes get a "K=..." legend entry; RK baselines (nSweeps is None)
    # are plotted without a label.
    if nSweeps is not None:
        params = getParamsSDC(quadType, nNodes, qDelta, nSweeps, nodeType)
        label = f"$K={nSweeps}$"
    else:
        params = getParamsRK(qDelta)
        label = None

    errors = []
    for nSteps in nStepsList:
        # Only the numerical solution is used; the other two returned values are ignored.
        uNum, counters, parallel = solutionSDC(tEnd, nSteps, params, 'DAHLQUIST', lambdas=np.array([lam]))
        # Exact solution exp(lam*t) at nSteps+1 uniformly spaced times.
        # NOTE(review): assumes uNum has one row per time point, initial value included.
        uExact = np.exp(lam * np.linspace(0, tEnd, nSteps + 1))
        errors.append(getError(uNum, uExact))

    plt.loglog(dtVals, errors, **style, label=label)
    if nSweeps is not None:
        # Dashed gray reference line scaling as dt**K.
        plt.loglog(dtVals, (0.1 * dtVals) ** nSweeps, '--', c='gray', lw=1.5)

plt.title(sweepType)
plt.legend()
plt.xlabel(r"$\Delta{t}$")
plt.ylabel(r"$L_\infty$ error")
plt.grid(True)
plt.tight_layout()