Source code for skecg.cs.codec_b

"""
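Compressive sensing codec for ECG signals with adaptive quantization.

Encoding: windows of the signal are measured with a sparse binary sensing
matrix; the integer measurements are quantized (fixed or adaptive),
clipped and entropy coded (ANS) frame by frame.
Decoding: the measurements are entropy decoded and each window is
reconstructed with BSBL-BO.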
"""


# std imports
import sys
import math
from decimal import Decimal
from typing import NamedTuple, List
import timeit

# NumPy
import jax
import numpy as np
import jax.numpy as jnp


# CR-Suite libraries
import cr.nimble as crn
import cr.nimble.dsp as crdsp
import cr.sparse as crs
import cr.sparse.dict as crdict
import cr.sparse.block.bsbl as bsbl


# Bitarray
from bitarray import bitarray
from bitarray.util import int2ba, ba2int
from cr.nimble.compression import *

# Entropy coding
import constriction

# Constants for Encoder Design
# Bit widths of the stream header fields (see serialize_encoder_params)
N_BITS = 12  # samples per window, params.n
M_BITS = 12  # measurements per window, params.m
D_BITS = 6  # ones per sensing-matrix column, params.d
W_BITS = 8  # windows per frame, params.w (also the window count in each frame header)
Q_BITS = 4  # fixed quantization parameter, params.q (non-adaptive mode only)
# Bounds of the right-shift amounts searched by encode_frame in adaptive mode
Q_MAX = 6
Q_MIN = 0

# Bit widths of the per-frame header fields (see encode_frame / read_measurements)
SEC_MEAN_BITS = 16  # mean of the quantized measurements (two's complement)
SEC_STD_BITS = 16  # standard deviation of the quantized measurements
SEC_Q_BITS = 3  # quantization parameter used for the frame; holds 0..7
SEC_RNG_BITS = 4  # clipping range multiplier
SEC_WORD_BITS = 16  # count of 32-bit ANS words in the frame payload


######################################################################
#                       ENCODER
######################################################################

class EncoderParams(NamedTuple):
    """The set of parameters to configure the ECG encoder.

    Instances are written into the stream header by
    ``serialize_encoder_params`` and read back by
    ``deserialize_encoder_params``.
    """
    key: jax.numpy.ndarray
    "PRNG Key for generating sensing matrix"
    n: int
    "number of samples per window"
    m: int
    "number of measurements per window"
    d: int
    "Number of ones per column in sparse binary sensing matrix"
    w: int
    "number of windows in each frame of the signal"
    adaptive: bool
    "A flag to indicate if the quantization is adaptive"
    q: int
    "quantization parameter if the quantization is fixed"
    q_nmse_limit: Decimal
    "NMSE limit for the quantization step"
    c_nmse_limit: Decimal
    "NMSE limit for the clipping step"

    @property
    def frame_length(self):
        """Length of each frame in samples (``n * w``)"""
        return self.n * self.w

    def __eq__(self, other):
        """Checks if two instances of EncoderParams are equal.

        Only the fields that are serialized into a bitstream are compared:
        ``q`` is ignored in adaptive mode and ``q_nmse_limit`` is ignored in
        fixed mode. This is useful for verifying that the encoding
        parameters were properly serialized and de-serialized.
        Comparing with an object of any other type returns False.
        """
        if not isinstance(other, self.__class__):
            return False
        if not jnp.all(self.key == other.key):
            return False
        # 'adaptive' is compared before the mode-specific field, and all()
        # short-circuits, so the mode-specific check only runs when both
        # instances are in the same mode.
        mode_field = 'q_nmse_limit' if self.adaptive else 'q'
        fields = ('n', 'm', 'd', 'w', 'adaptive', mode_field, 'c_nmse_limit')
        return all(getattr(self, name) == getattr(other, name) for name in fields)

    def __ne__(self, other):
        """Inverse of ``__eq__``.

        Without this override ``!=`` falls back to ``tuple.__ne__``, which
        compares the ``key`` arrays element-wise and then tries to convert
        the resulting array to a single bool, which may raise.
        """
        return not self.__eq__(other)
class EncodedFrame(NamedTuple):
    """Summary of a single encoded frame.

    Collects the frame header values (quantization, clipping range, entropy
    model statistics), the bit counts of header and payload, and the error
    metrics of the quantization and clipping steps.
    """
    n_measurements: int
    "Number of measurements"
    n_windows: int
    "Number of windows"
    max_val: int
    "Maximum value of measurements"
    min_val: int
    "Minimum value of measurements"
    mean_val: int
    "Mean value of measurements (integer)"
    std_val: int
    "Standard deviation of measurements (integer)"
    q: int
    "Quantization parameter used for encoding the frame data"
    rng_mult: int
    "Range multiplier used for restricting the set of values"
    n_words: int
    "Number of words in the compressed entropy coded frame payload"
    n_header_bits: int
    "Number of bits in the frame header"
    n_payload_bits: int
    "Number of bits in the frame payload"
    n_bits: int
    "Total number of bits in the encoded frame"
    q_nmse: float
    "NMSE for the quantization step"
    c_nmse: float
    "NMSE for the clipping step"
    qc_nmse: float
    "NMSE for combined quantization and clipping"
    qc_snr: float
    "SNR for combined quantization and clipping"

    @property
    def overhead(self):
        "Header bits as a fraction of the payload bits"
        header_bits, payload_bits = self.n_header_bits, self.n_payload_bits
        return header_bits / payload_bits
class EncodedStream(NamedTuple):
    """Summary of an encoded bitstream.

    Holds the stream-level counts together with the per-frame summaries and
    derives compression metrics from them. The uncompressed size is taken
    as 11 bits per sample.
    """
    n_samples: int
    "Number of samples in the bitstream"
    n_windows: int
    "Number of windows in the bitstream"
    n_frames: int
    "Number of frames in the bitstream"
    n_measurements: int
    "Number of measurements in the bitstream"
    n_header_bits: int
    "Number of bits in the stream header"
    n_bits: int
    "Total number of bits in the bitstream"
    frames: List[EncodedFrame]
    "List of encoded frames"

    def _per_frame(self, attr):
        """Gathers attribute ``attr`` of every frame into a NumPy array."""
        return np.array([getattr(frame, attr) for frame in self.frames])

    @property
    def q_vals(self):
        """Quantization parameter of every frame"""
        return self._per_frame('q')

    @property
    def mean_vals(self):
        """Measurement mean value of every frame"""
        return self._per_frame('mean_val')

    @property
    def std_vals(self):
        """Measurement standard deviation of every frame"""
        return self._per_frame('std_val')

    @property
    def rng_mults(self):
        """Clipping range multiplier of every frame"""
        return self._per_frame('rng_mult')

    @property
    def overheads(self):
        """Fractional header overhead of every frame"""
        return self._per_frame('overhead')

    @property
    def q_nmses(self):
        """NMSE of the quantization step of every frame"""
        return self._per_frame('q_nmse')

    @property
    def c_nmses(self):
        """NMSE of the clipping step of every frame"""
        return self._per_frame('c_nmse')

    @property
    def qc_nmses(self):
        """NMSE of quantization + clipping of every frame"""
        return self._per_frame('qc_nmse')

    @property
    def qc_snrs(self):
        """SNR of quantization + clipping of every frame"""
        return self._per_frame('qc_snr')

    @property
    def overhead_bits(self):
        """Stream header bits plus the header bits of all frames"""
        frame_header_bits = np.sum([frame.n_header_bits for frame in self.frames])
        return frame_header_bits + self.n_header_bits

    @property
    def total_overhead(self):
        """Fraction of the bitstream spent on headers"""
        return self.overhead_bits / self.n_bits

    @property
    def compressed_bits(self):
        """Size of the encoded bitstream in bits"""
        return self.n_bits

    @property
    def uncompressed_bits(self):
        """Size of the original signal in bits (11 bits per sample)"""
        return self.n_samples * 11

    @property
    def cr(self):
        """Compression ratio"""
        return crn.compression_ratio(self.uncompressed_bits, self.compressed_bits)

    @property
    def pss(self):
        """Percentage space savings"""
        return crn.percent_space_saving(self.uncompressed_bits, self.compressed_bits)

    @property
    def bps(self):
        """Compressed bits per sample"""
        return self.compressed_bits / self.n_samples

    @property
    def bpm(self):
        """Compressed bits per measurement"""
        return self.compressed_bits / self.n_measurements

    def __str__(self):
        lines = [
            f'n_samples={self.n_samples}',
            f'n_measurements={self.n_measurements}',
            f'n_windows={self.n_windows}',
            f'n_frames={self.n_frames}',
            f'n_header_bits={self.n_header_bits}',
            f'overhead_bits={self.overhead_bits}',
            f'compressed_bits={self.compressed_bits}',
            f'uncompressed_bits={self.uncompressed_bits}',
            f'compression_ratio={self.cr:.2f}',
            f'percent_space_saving={self.pss:.1f} %',
            f'bps={self.bps:.2f}',
            f'bpm={self.bpm:.2f}',
            f'overhead={self.total_overhead * 100:.2f} %',
        ]
        return '\n'.join(lines)
class EncodedData(NamedTuple):
    """Result of ``encode``.

    Bundles the stream summary, the integer measurements of all frames and
    the compressed bitstream.
    """
    info: EncodedStream
    "Summarized information about the encoded bitstream"
    y: np.ndarray
    "Measurement values array (across all frames)"
    bits: bitarray
    "Encoded (compressed) bitstream"
def serialize_encoder_params(params: EncoderParams):
    """Serializes encoding parameters into a bitarray.

    Layout (fields written back to back, no padding):

    * PRNG key: two 32-bit words
    * n (N_BITS), m (M_BITS), d (D_BITS), w (W_BITS)
    * adaptive flag (1 bit)
    * fixed mode: q (Q_BITS); adaptive mode: q_nmse_limit as a 4-bit
      digit followed by a 4-bit negated exponent
    * c_nmse_limit as a 4-bit digit followed by a 4-bit negated exponent

    Args:
        params: The encoder configuration to serialize.

    Returns:
        A bitarray holding the stream header (not byte aligned).

    Raises:
        ValueError: If an NMSE limit is not of the form ``digit * 10**-k``
            with a single decimal digit and ``0 <= k <= 15``.
    """
    a = bitarray()

    def write_limit(limit):
        # normalize() strips trailing zeros: Decimal('0.10') becomes 1E-1.
        # Without it digits would be (1, 0) with exponent -2, and writing
        # only the first digit would silently turn the limit into 0.01.
        sign, digits, exp = limit.normalize().as_tuple()
        if sign or len(digits) != 1 or not isinstance(exp, int) or not 0 <= -exp <= 15:
            raise ValueError(
                f'NMSE limit {limit} must be a single digit times 10**-k with 0 <= k <= 15')
        a.extend(int2ba(digits[0], 4))
        a.extend(int2ba(-exp, 4))

    # np.asarray works for JAX and NumPy keys; DeviceArray.to_py() no longer
    # exists in current JAX releases.
    key = np.asarray(params.key)
    a.extend(int2ba(int(key[0]), 32))
    a.extend(int2ba(int(key[1]), 32))
    a.extend(int2ba(params.n, N_BITS))
    a.extend(int2ba(params.m, M_BITS))
    a.extend(int2ba(params.d, D_BITS))
    a.extend(int2ba(params.w, W_BITS))
    a.append(params.adaptive)
    if params.adaptive:
        write_limit(params.q_nmse_limit)
    else:
        a.extend(int2ba(params.q, Q_BITS))
    # the clipping limit is present in both modes
    write_limit(params.c_nmse_limit)
    return a
def build_sensor(params: EncoderParams):
    """Creates the sparse binary sensing matrix described by ``params``.

    The matrix maps a window of ``params.n`` samples to ``params.m``
    measurements, has ``params.d`` ones per column (atoms are not
    normalized) and is generated from ``params.key``. The decoder rebuilds
    the same matrix from the deserialized parameters.
    """
    return crdict.sparse_binary_mtx(
        params.key, params.m, params.n,
        d=params.d, normalize_atoms=False)
def sense(params, Phi, ecg):
    """Measures an ECG signal window by window.

    The signal is split into windows of ``params.n`` samples (one window
    per column), each window is multiplied by ``Phi``, and the measurement
    matrix is converted to integers and flattened column by column, so the
    measurements of one window stay contiguous.
    """
    windows = crn.vec_to_windows(ecg, params.n)
    measurements = np.array(Phi @ windows).astype(int)
    return measurements.flatten(order='F')
def encode(params: EncoderParams, ecg: np.ndarray):
    """Compresses an ECG signal into a bitstream.

    Pipeline:

    * write the stream header (encoder parameters), zero-padded to a byte
      boundary;
    * build the sensing matrix and compute the measurements of all windows;
    * cut the measurements into frames of ``params.w`` windows (the last
      frame may be shorter) and encode each one with ``encode_frame``;
    * append the frame bitstreams one after another.

    Returns:
        EncodedData with the stream summary, the integer measurements and
        the complete bitstream.
    """
    stream = bitarray()
    stream.extend(serialize_encoder_params(params))
    stream.fill()
    n_header_bits = len(stream)

    y = sense(params, build_sensor(params), ecg)
    n_measurements = y.size
    n_windows = n_measurements // params.m
    n_samples = params.n * n_windows
    n_frames = math.ceil(n_windows / params.w)

    # number of measurements in a full frame
    frame_size = params.m * params.w
    frames = []
    for i_frame in range(n_frames):
        begin = i_frame * frame_size
        frame_info, frame_bits = encode_frame(params, y[begin:begin + frame_size])
        stream.extend(frame_bits)
        frames.append(frame_info)

    info = EncodedStream(
        n_samples=n_samples,
        n_windows=n_windows,
        n_frames=n_frames,
        n_measurements=n_measurements,
        n_header_bits=n_header_bits,
        n_bits=len(stream),
        frames=frames,
    )
    return EncodedData(info=info, y=y, bits=stream)
def encode_frame(params: EncoderParams, y: np.ndarray):
    """Encodes a single frame of measurements.

    Processing steps:

    1. Quantization: ``yq = y >> q``. In adaptive mode ``q`` is searched
       from ``Q_MAX`` down to ``Q_MIN`` (both inclusive). The first value
       whose error (``crn.normalized_root_mse``) is within
       ``params.q_nmse_limit`` is kept. In fixed mode ``params.q`` is used.
    2. Clipping: ``yq`` is clipped to ``mean_val +/- rng_mult * std_val``,
       using the first ``rng_mult`` in ``2..8`` whose error is within
       ``params.c_nmse_limit``. If none qualifies, 8 is used.
    3. Entropy coding of the clipped values with a quantized Gaussian model
       and constriction's ANS coder.
    4. Serialization: a byte-aligned frame header followed by the
       compressed 32-bit words.

    Args:
        params: Encoder configuration.
        y: Integer measurements of this frame (at most ``params.w`` windows
            of ``params.m`` measurements each).

    Returns:
        A tuple ``(info, bits)`` with the EncodedFrame summary and the
        encoded frame as a bitarray.
    """
    n_measurements = len(y)
    n_windows = n_measurements // params.m
    q_nmse_limit = float(params.q_nmse_limit)
    c_nmse_limit = float(params.c_nmse_limit)

    # Step 1: quantization by dropping the q least significant bits.
    if params.adaptive:
        # range() excludes its stop value, so stop at Q_MIN - 1 to make sure
        # Q_MIN itself (q = 0, i.e. lossless) is tried as the last resort.
        candidates = range(Q_MAX, Q_MIN - 1, -1)
    else:
        candidates = (params.q,)
    for q in candidates:
        yq = y >> q
        yhat = yq << q
        q_nmse = crn.normalized_root_mse(y, yhat)
        if q_nmse <= q_nmse_limit:
            # we have achieved the desired quantization
            break

    max_val = np.max(yq)
    min_val = np.min(yq)
    mean_val = int(np.round(yq.mean()))
    # the entropy model needs a strictly positive standard deviation
    std_val = max(int(np.ceil(yq.std())), 1)

    # Step 2: clip to mean +/- rng_mult * std
    for rng_mult in range(2, 9):
        a_min = int(mean_val - rng_mult * std_val)
        a_max = int(mean_val + rng_mult * std_val)
        yc = np.clip(yq, a_min, a_max)
        c_nmse = crn.normalized_root_mse(yq, yc)
        if c_nmse <= c_nmse_limit:
            break

    # error of the combined quantization + clipping
    yhat = yc << q
    qc_nmse = crn.normalized_root_mse(y, yhat)
    qc_snr = crn.signal_noise_ratio(y, yhat)

    # Step 3: entropy coding
    model = constriction.stream.model.QuantizedGaussian(a_min, a_max,
        mean=mean_val, std=std_val)
    encoder = constriction.stream.stack.AnsCoder()
    encoder.encode_reverse(yc, model)
    compressed = encoder.get_compressed()
    n_words = len(compressed)

    # Step 4: frame header (byte aligned) followed by the payload words
    stream = bitarray()
    stream.extend(int2ba(mean_val, SEC_MEAN_BITS, signed=True))
    stream.extend(int2ba(std_val, SEC_STD_BITS))
    stream.extend(int2ba(q, SEC_Q_BITS))
    stream.extend(int2ba(rng_mult, SEC_RNG_BITS))
    stream.extend(int2ba(n_windows, W_BITS))
    stream.extend(int2ba(n_words, SEC_WORD_BITS))
    stream.fill()
    n_header_bits = len(stream)
    for word in compressed:
        stream.extend(int2ba(int(word), 32))
    n_bits = len(stream)
    n_payload_bits = n_bits - n_header_bits

    info = EncodedFrame(n_measurements=n_measurements, n_windows=n_windows,
        max_val=max_val, min_val=min_val,
        mean_val=mean_val, std_val=std_val,
        q=q, rng_mult=rng_mult, n_words=n_words,
        n_header_bits=n_header_bits, n_payload_bits=n_payload_bits,
        n_bits=n_bits,
        q_nmse=float(q_nmse), c_nmse=float(c_nmse),
        qc_nmse=float(qc_nmse), qc_snr=float(qc_snr))
    return info, stream
######################################################################
#                       DECODER
######################################################################


def deserialize_encoder_params(bits: bitarray, pos=0):
    """Reads the encoding parameters from a stream header.

    The field layout mirrors ``serialize_encoder_params``.

    Args:
        bits: The bitstream.
        pos: Bit offset at which the stream header starts.

    Returns:
        A tuple ``(params, pos)``: the EncoderParams and the bit offset just
        past the header (not rounded to a byte boundary).
    """
    def read(width):
        # Read an unsigned field and advance the cursor. Routing every field
        # through here keeps the offset relative to the caller's start pos
        # (the original code reset pos to 32 after the first key word).
        nonlocal pos
        value = ba2int(bits[pos:pos + width])
        pos += width
        return value

    def read_limit():
        # 4-bit digit followed by a 4-bit negated decimal exponent
        digit = read(4)
        exp = read(4)
        return Decimal((0, (digit,), -exp))

    key0 = read(32)
    key1 = read(32)
    key = jnp.array([key0, key1], dtype=jnp.uint32)
    n = read(N_BITS)
    m = read(M_BITS)
    d = read(D_BITS)
    w = read(W_BITS)
    adaptive = bool(read(1))
    q = 0
    if adaptive:
        q_nmse_limit = read_limit()
    else:
        q = read(Q_BITS)
        q_nmse_limit = Decimal((0, (0,), 0))
    c_nmse_limit = read_limit()
    params = EncoderParams(key=key, n=n, m=m, d=d, w=w, adaptive=adaptive,
        q=q, q_nmse_limit=q_nmse_limit, c_nmse_limit=c_nmse_limit)
    return params, pos


def next_byte_pos(pos):
    """Rounds a non-negative bit offset up to the next multiple of 8.

    Offsets that are already byte aligned are returned unchanged.
    """
    return (pos + 7) & (-8)
class DecodedData(NamedTuple):
    """Result of ``decode``.

    Holds the reconstructed signal, the decoded measurements and per-window
    reconstruction statistics.
    """
    x: np.ndarray
    "Decoded ECG signal"
    y_hat: np.ndarray
    "Decoded measurements (after entropy decoding and inverse quantization)"
    r_times: np.ndarray
    "Reconstruction time of each window"
    r_iters: np.ndarray
    "Number of reconstruction iterations of each window"

    @property
    def total_time(self):
        """Sum of the per-window reconstruction times"""
        return np.sum(self.r_times)
def decode(bits: bitarray, block_size=32):
    """Reconstructs an ECG signal from an encoded bitstream.

    Steps:

    * read the stream header (encoder parameters);
    * starting at the next byte boundary, read and entropy decode all
      frames to get the dequantized measurements;
    * reconstruct every window independently with BSBL-BO (at most 20
      iterations), using the sensing matrix rebuilt from the header and
      printing the time taken per window;
    * concatenate the reconstructed windows into one signal.

    Args:
        bits: The encoded bitstream.
        block_size: Block size for the BSBL reconstruction algorithm.

    Returns:
        DecodedData with the signal, the decoded measurements and the
        per-window reconstruction times and iteration counts.
    """
    params, header_end = deserialize_encoder_params(bits)
    y_hat = read_measurements(params, bits, next_byte_pos(header_end))

    # one column of m measurements per window
    Yhat = crn.vec_to_windows(jnp.asarray(y_hat, dtype=float), params.m)
    n_windows = Yhat.shape[1]

    options = bsbl.bsbl_bo_options(max_iters=20)
    X_hat = np.zeros((params.n, n_windows))
    r_times = np.zeros(n_windows)
    r_iters = np.zeros(n_windows, dtype=int)
    DPhi = build_sensor(params).todense()

    for i in range(n_windows):
        y = Yhat[:, i]
        tic = timeit.default_timer()
        sol = bsbl.bsbl_bo_np_jit(DPhi, y, block_size, options=options)
        rtime = timeit.default_timer() - tic
        X_hat[:, i] = sol.x
        r_times[i] = rtime
        r_iters[i] = sol.iterations
        print(f'[{i}/{n_windows}], time: {rtime:.2f} sec')

    # windows are columns; Fortran order lays them out one after another
    x = X_hat.flatten(order='F')
    return DecodedData(x=x, y_hat=y_hat, r_times=r_times, r_iters=r_iters)
def read_measurements(params, bits, pos):
    """Entropy decodes and dequantizes the measurements of all frames.

    Frames are read one after another from bit offset ``pos`` until the end
    of ``bits``. Each frame has a byte-aligned header (mean, std, q, range
    multiplier, window count, word count) followed by its 32-bit ANS words.

    Returns:
        The concatenated dequantized (and clipped) measurements of all
        frames.
    """
    def take(width, signed=False):
        # read one header/payload field and advance the cursor
        nonlocal pos
        value = ba2int(bits[pos:pos + width], signed=signed)
        pos += width
        return value

    total_bits = len(bits)
    decoded_frames = []
    while pos < total_bits:
        mean_val = take(SEC_MEAN_BITS, signed=True)
        std_val = take(SEC_STD_BITS)
        q = take(SEC_Q_BITS)
        rng_mult = take(SEC_RNG_BITS)
        n_windows = take(W_BITS)
        n_words = take(SEC_WORD_BITS)
        # the payload starts at the next byte boundary
        pos = next_byte_pos(pos)
        compressed = np.array([take(32) for _ in range(n_words)], dtype=np.uint32)

        # rebuild the same clipped Gaussian model the encoder used
        a_min = int(mean_val - rng_mult * std_val)
        a_max = int(mean_val + rng_mult * std_val)
        model = constriction.stream.model.QuantizedGaussian(a_min, a_max,
            mean=mean_val, std=std_val)
        ans_decoder = constriction.stream.stack.AnsCoder(compressed)
        yc = ans_decoder.decode(model, params.m * n_windows)
        # undo the quantization shift
        decoded_frames.append(yc << q)
    return np.concatenate(decoded_frames)


######################################################################
#                       COMPARISON
######################################################################
class CompressionStats(NamedTuple):
    """Metrics of one encode/decode round trip.

    Covers bit counts and rates, signal-domain reconstruction quality,
    reconstruction time and measurement-domain quantization/clipping error.
    """
    u_bits: int
    "Uncompressed bits count"
    c_bits: int
    "compressed bits count"
    bpm: float
    "bits per measurement"
    bps: float
    "bits per sample"
    cr: float
    "compression ratio"
    pss: float
    "percentage space savings"
    snr: float
    "signal to noise ratio (dB)"
    prd: float
    "percent root mean square difference"
    nmse: float
    "normalized mean square difference"
    qs: float
    "Quality score"
    rtime: float
    "reconstruction time"
    qc_snr: float
    "signal to noise ratio (dB) for quantization+clipping"
    qc_prd: float
    "percent root mean square difference for quantization+clipping"
    qc_nmse: float
    "NMSE for quantization+clipping"
def compression_stats(ecg, coded_ecg, decoded_ecg):
    """Computes the compression statistics of an encode/decode round trip.

    Prints a short summary while computing.

    Args:
        ecg: The original ECG signal. Only its first ``info.n_samples``
            samples (the part covered by whole windows) are compared.
        coded_ecg: The EncodedData returned by ``encode``.
        decoded_ecg: The DecodedData returned by ``decode``.

    Returns:
        A CompressionStats. All float-typed fields are plain Python floats.
    """
    info = coded_ecg.info
    n_samples = info.n_samples
    n_measurements = info.n_measurements
    y = coded_ecg.y

    # the uncompressed size assumes 11 bits per sample
    uncompressed_bits = n_samples * 11
    compressed_bits = len(coded_ecg.bits)
    ratio = crn.compression_ratio(uncompressed_bits, compressed_bits)
    pss = crn.percent_space_saving(uncompressed_bits, compressed_bits)
    bpm = compressed_bits / n_measurements
    bps = compressed_bits / n_samples
    y_max = np.max(np.abs(y))
    print(f'Uncompressed bits: {uncompressed_bits} Compressed bits: {compressed_bits}, ratio: {ratio:.2f}x')
    print(f'bits per measurement in compressed data: {bpm:.2f}')
    print(f'bits per measurement in cs measurements: {np.round(np.log2(2* y_max + 1))}')
    print(f'Compressed bits per sample: {bps:.2f}')

    # signal-domain quality
    rtime = decoded_ecg.total_time
    x = ecg[:n_samples]
    x_hat = decoded_ecg.x
    snr = crn.signal_noise_ratio(x, x_hat)
    prd = crn.percent_rms_diff(x, x_hat)
    nmse = crn.normalized_mse(x, x_hat)
    qs = float(ratio * 100 / prd)
    print(f'SNR: {snr:.2f} dB, PRD: {prd:.1f}%, QS: {qs:.5f}, Time: {rtime:.2f} sec')

    # measurement-domain error introduced by quantization + clipping
    y_hat = decoded_ecg.y_hat
    qc_snr = crn.signal_noise_ratio(y, y_hat)
    qc_prd = crn.percent_rms_diff(y, y_hat)
    qc_nmse = crn.normalized_mse(y, y_hat)
    return CompressionStats(
        u_bits=uncompressed_bits,
        c_bits=compressed_bits,
        cr=ratio,
        pss=pss,
        bpm=bpm,
        bps=bps,
        snr=float(snr),
        prd=float(prd),
        nmse=float(nmse),
        qs=qs,
        rtime=float(rtime),
        qc_snr=float(qc_snr),
        qc_prd=float(qc_prd),
        qc_nmse=float(qc_nmse),
    )