Coverage for pySDC/projects/compression/compression_convergence_controller.py: 100%
23 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-09 14:59 +0000
from pySDC.core.convergence_controller import ConvergenceController
import numpy as np

# Compatibility shim: restore `np.bool` as an alias of `np.bool_` *before* importing libpressio.
# NOTE(review): presumably libpressio still accesses `np.bool`, which some NumPy versions do not
# provide -- confirm against the installed libpressio/NumPy versions before removing this line.
np.bool = np.bool_
import libpressio
class Compression(ConvergenceController):
    """
    Convergence controller that replaces the solution by its lossily compressed value.

    In ``post_iteration_processing``, every entry of ``lvl.u`` is passed through a libpressio
    compressor (encode followed by decode) and overwritten in place with the decompressed
    values. The right-hand side ``lvl.f`` is then re-evaluated at the new values.

    Parameters (given via ``params``):
        compressor_args (dict): Overrides for the libpressio configuration. Top-level keys are
            merged shallowly into the defaults, so e.g. a user-supplied ``compressor_config``
            replaces the default one entirely. Defaults to SZ3 with ``"pressio:abs": 1e-10``.
        min_buffer_length (int): Minimum length of the buffer handed to the compressor; shorter
            solutions are zero-padded up to this length. Defaults to 12.
    """

    def setup(self, controller, params, description, **kwargs):
        """
        Resolve the parameters of this convergence controller and build the compressor.

        Args:
            controller: The controller this convergence controller is attached to.
            params (dict): User-supplied parameters for this convergence controller.
            description (dict): The description used to construct the controller.

        Returns:
            dict: The resolved parameters (defaults overridden by user parameters).
        """
        default_compressor_args = {
            # configure which compressor to use
            "compressor_id": "sz3",
            # configure the set of metrics to be gathered
            "early_config": {"pressio:metric": "composite", "composite:plugins": ["time", "size", "error_stat"]},
            # configure SZ
            "compressor_config": {
                "pressio:abs": 1e-10,
            },
        }

        defaults = {
            'control_order': 0,
            # Plain defaults must come *before* the user parameters so that user values win.
            'min_buffer_length': 12,
            **super().setup(controller, params, description, **kwargs),
            # Shallow merge: top-level keys supplied by the user replace the defaults.
            'compressor_args': {**default_compressor_args, **(params.get('compressor_args') or {})},
        }

        self.compressor = libpressio.PressioCompressor.from_config(defaults['compressor_args'])

        return defaults

    def post_iteration_processing(self, controller, S, **kwargs):
        """
        Replace the solution by the compressed value.

        Every entry of ``lvl.u`` is encoded and decoded by the compressor and overwritten in
        place with the decompressed values. Because the solution changed, ``lvl.f`` is
        re-evaluated at the corresponding times.

        Args:
            controller: The controller.
            S: The current step. Only steps with a single level are supported.

        Returns:
            None
        """
        assert len(S.levels) == 1, 'Compression is only implemented for single-level steps'
        lvl = S.levels[0]
        prob = lvl.prob
        # Prepend 0 for lvl.u[0]; presumably the collocation nodes exclude the left interval
        # boundary, so that nodes[i] matches lvl.u[i] -- confirm for the collocation in use.
        nodes = np.append(0, lvl.sweep.coll.nodes)

        # The compressor works on flat buffers. Solutions shorter than min_buffer_length are
        # zero-padded; the padding region is never written below, so it stays zero.
        encode_buffer = np.zeros(max(lvl.u[0].size, self.params.min_buffer_length))
        decode_buffer = np.zeros_like(encode_buffer)

        for i, u in enumerate(lvl.u):
            n = u.size
            # Flatten so that multi-dimensional solutions are supported as well.
            encode_buffer[:n] = np.ravel(u)
            comp_data = self.compressor.encode(encode_buffer)
            decode_buffer = self.compressor.decode(comp_data, decode_buffer)

            # Write back in place (restoring the original shape) so the solution object is kept.
            u[:] = decode_buffer[:n].reshape(u.shape)
            lvl.f[i] = prob.eval_f(u, lvl.time + lvl.dt * nodes[i])

        # Compression statistics (time, size, error) are available via self.compressor.get_metrics().