Coverage for pySDC / helpers / vtkIO.py: 98%

51 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-12 11:13 +0000

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3""" 

4Helper functions for VTK file I/O (to be used with ParaView or PyVista)

5""" 

6 

7import os 

8import vtk 

9from vtkmodules.util import numpy_support 

10import numpy as np 

11 

12 

def writeToVTR(fileName: str, data, coords, varNames):
    """
    Write a data array containing variables from a 3D rectilinear grid into a VTR file.

    Parameters
    ----------
    fileName : str or os.PathLike
        Name of the VTR file, can be with or without the .vtr extension
        (the extension is appended when missing).
    data : np.ndarray
        Array containing all the variables with [nVar, nX, nY, nZ] shape.
        Values are converted to 32-bit floats (VTK_FLOAT) when written.
    coords : list[np.1darray]
        Coordinates in each direction (x, y, z). Each entry is flattened, so it
        may carry extra singleton axes as long as it holds nX, nY and nZ values
        respectively.
    varNames : list[str]
        Variable names, one per entry along the first axis of ``data``.

    Returns
    -------
    fileName : str
        Name of the VTR file (including the .vtr extension).

    Raises
    ------
    AssertionError
        If ``data`` is not 4D, or if ``varNames`` / ``coords`` sizes do not match ``data``.
    OSError
        If the VTK writer reports a failure.
    """
    fileName = os.fspath(fileName)
    data = np.asarray(data)
    nVar, *gridSizes = data.shape

    assert len(gridSizes) == 3, "function can be used only for 3D grid data"
    assert nVar == len(varNames), "varNames must have as many variable as data"
    # Flatten once so that the size check and the VTK conversion below see the
    # same 1D arrays (coordinates with extra singleton axes pass the check and
    # must also be converted as plain 1D arrays).
    coords = [np.ravel(coord) for coord in coords]
    assert [coord.size for coord in coords] == gridSizes, "coordinate size incompatible with data shape"

    nX, nY, nZ = gridSizes
    vtr = vtk.vtkRectilinearGrid()
    vtr.SetDimensions(nX, nY, nZ)

    def toVTK(array):
        # Deep copy into a VTK array of 32-bit floats.
        return numpy_support.numpy_to_vtk(num_array=array, deep=True, array_type=vtk.VTK_FLOAT)

    x, y, z = coords
    vtr.SetXCoordinates(toVTK(x))
    vtr.SetYCoordinates(toVTK(y))
    vtr.SetZCoordinates(toVTK(z))

    pointData = vtr.GetPointData()
    for name, u in zip(varNames, data, strict=True):
        # Flatten in Fortran order (first index varies fastest).
        uVTK = toVTK(u.ravel(order='F'))
        uVTK.SetName(name)
        pointData.AddArray(uVTK)

    writer = vtk.vtkXMLRectilinearGridWriter()
    if not fileName.endswith(".vtr"):
        fileName += ".vtr"
    writer.SetFileName(fileName)
    writer.SetInputData(vtr)
    # Write() signals failure through its return value rather than raising,
    # so check it explicitly instead of silently producing no file.
    if not writer.Write():
        raise OSError(f"failed to write VTR file {fileName}")

    return fileName

69 

70 

def readFromVTR(fileName: str):
    """
    Read a VTR file into a numpy 4darray

    Parameters
    ----------
    fileName : str or os.PathLike
        Name of the VTR file, can be with or without the .vtr extension.

    Returns
    -------
    data : np.ndarray
        Array containing all the variables with [nVar, nX, nY, nZ] shape
        (nVar is 0 when the file holds no point data).
    coords : list[np.1darray]
        Coordinates in each direction (x, y, z).
    varNames : list[str]
        Variable names, in the order the point arrays are stored in the file.

    Raises
    ------
    AssertionError
        If the file does not exist.
    """
    fileName = os.fspath(fileName)
    if not fileName.endswith(".vtr"):
        fileName += ".vtr"
    assert os.path.isfile(fileName), f"{fileName} does not exist"

    reader = vtk.vtkXMLRectilinearGridReader()
    reader.SetFileName(fileName)
    reader.Update()

    vtr = reader.GetOutput()
    gridSizes = vtr.GetDimensions()
    # NOTE(review): GetDimensions() appears to always return three values, so
    # this check may be vacuous -- confirm whether a stricter test was intended.
    assert len(gridSizes) == 3, "can only read 3D data"

    toNumpy = numpy_support.vtk_to_numpy
    coords = [toNumpy(vtr.GetXCoordinates()), toNumpy(vtr.GetYCoordinates()), toNumpy(vtr.GetZCoordinates())]

    pointData = vtr.GetPointData()
    nArrays = pointData.GetNumberOfArrays()
    varNames = [pointData.GetArrayName(i) for i in range(nArrays)]
    # Fetch arrays by index rather than by name: a name lookup would presumably
    # resolve duplicated variable names to the same (first) array.
    # Values are stored with the first index varying fastest, hence order="F".
    fields = [toNumpy(pointData.GetArray(i)).reshape(gridSizes, order="F") for i in range(nArrays)]
    # Explicit reshape so a file without point data still yields a 4D (0, nX, nY, nZ) array.
    data = np.array(fields).reshape((nArrays, *gridSizes))

    return data, coords, varNames