Coverage for pySDC/projects/compression/compression_convergence_controller.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-29 09:02 +0000

1from pySDC.core.ConvergenceController import ConvergenceController 

2import numpy as np 

3 

# NOTE(review): compatibility shim that must run before ``libpressio`` is
# imported below. Presumably the libpressio Python bindings still reference the
# ``np.bool`` alias that NumPy removed in 1.24 — confirm against the installed
# libpressio version. Point the alias at NumPy's boolean scalar type unless the
# module namespace already holds exactly that object. ``vars`` is used instead
# of ``getattr`` so numpy's deprecation ``__getattr__`` hook is never triggered.
if vars(np).get("bool") is not np.bool_:
    np.bool = np.bool_

5import libpressio 

6 

7 

class Compression(ConvergenceController):
    """
    Convergence controller that passes the collocation solution through a lossy compressor.

    After every iteration, each stored node value ``lvl.u[i]`` is encoded and
    immediately decoded again with a libpressio compressor and overwritten in
    place by the decoded result. The right-hand side ``lvl.f[i]`` is then
    re-evaluated at that value, so the run continues with the lossy data.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Merge the user parameters with the defaults and build the compressor.

        Args:
            controller (pySDC.Controller): The controller
            params (dict): Parameters for this convergence controller. Recognised keys:
                ``compressor_args`` (dict) is merged one level deep into the default
                libpressio configuration below (a nested dict given by the user
                replaces the default one wholesale); ``min_buffer_length`` (int)
                is the minimal length of the buffer handed to the compressor.
            description (dict): The description object used to instantiate the controller

        Returns:
            dict: The merged parameters
        """
        default_compressor_args = {
            # configure which compressor to use
            "compressor_id": "sz3",
            # configure the set of metrics to be gathered
            "early_config": {"pressio:metric": "composite", "composite:plugins": ["time", "size", "error_stat"]},
            # configure SZ
            "compressor_config": {
                "pressio:abs": 1e-10,
            },
        }

        defaults = {
            'control_order': 0,
            # Listed before the entries returned by the parent's ``setup`` (which
            # carry the user parameters) so that a user-supplied value wins instead
            # of being silently overwritten.
            'min_buffer_length': 12,
            **super().setup(controller, params, description, **kwargs),
            # Listed last on purpose: the defaults are merged with the user's
            # compressor arguments rather than replaced by them.
            'compressor_args': {**default_compressor_args, **params.get('compressor_args', {})},
        }

        self.compressor = libpressio.PressioCompressor.from_config(defaults['compressor_args'])

        return defaults

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Replace the solution at every node by its compressed-then-decompressed value.

        Each ``lvl.u[i]`` is flattened and copied into a zero-padded buffer of at
        least ``min_buffer_length`` entries. The buffer is encoded and decoded with
        ``self.compressor``, and the result is written back into ``lvl.u[i]`` in
        place with its original shape. ``lvl.f[i]`` is then re-evaluated at the
        new value.

        Args:
            controller (pySDC.Controller): The controller
            S (pySDC.Step): The current step; only single-level steps are supported

        Returns:
            None
        """
        assert len(S.levels) == 1, 'Compression is only implemented for single-level steps'
        lvl = S.levels[0]
        prob = lvl.prob
        # Prepend 0 so that nodes[i] is the relative time of lvl.u[i]; lvl.u[0]
        # sits at the start of the step (lvl.time).
        nodes = np.append(0, lvl.sweep.coll.nodes)

        # Short solutions are zero-padded up to min_buffer_length (presumably
        # because the compressor does not cope with very short inputs — TODO
        # confirm). The padding is never overwritten, so the same buffers are
        # reused for every node.
        encode_buffer = np.zeros(max(np.size(lvl.u[0]), self.params.min_buffer_length))
        decode_buffer = np.zeros_like(encode_buffer)

        for i, u in enumerate(lvl.u):
            # Flatten so that multi-dimensional solutions fit the 1-D buffer.
            n = np.size(u)
            encode_buffer[:n] = np.ravel(u)
            comp_data = self.compressor.encode(encode_buffer)
            decode_buffer = self.compressor.decode(comp_data, decode_buffer)

            # Write back in place so that lvl.u[i] keeps its identity and type.
            u[...] = np.reshape(decode_buffer[:n], np.shape(u))
            lvl.f[i] = prob.eval_f(u, lvl.time + lvl.dt * nodes[i])

        # Compression statistics can be inspected via self.compressor.get_metrics().