# hadt-api / preproc.py
import pickle
from io import StringIO

import numpy as np
import pandas as pd
from scipy.signal import resample
from sklearn import preprocessing
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
from tslearn.utils import to_time_series_dataset
from wfdb import processing


def preproc(X):
    """Scale each fixed-length beat to zero mean and unit variance; called at inference time by the API."""
    in_shape = X.shape
    if X.shape[1] != 180:
        print(f"Expected input of shape (n, 180) but got {in_shape}")
    X = to_time_series_dataset(X)
    X = X.reshape(in_shape[0], -1)
    scaler = TimeSeriesScalerMeanVariance()
    X = scaler.fit_transform(X)
    return X.reshape(in_shape)
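
# Example (sketch): `preproc` expects an array of fixed-length beats with shape (n, 180).
# The array below is synthetic and only illustrates the expected shapes.
#   beats = np.random.randn(4, 180)
#   scaled = preproc(beats)   # shape (4, 180); each row rescaled to zero mean / unit variance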


def apple_csv_to_data(file_content):
    """Parse the raw bytes of an Apple Health ECG export into (samples DataFrame, sample rate in Hz)."""
    text = file_content.decode("utf-8")
    # Extract the sampling rate from the metadata header; the value after the comma
    # starts with the numeric rate.
    sample_rate = None
    for il, line in enumerate(text.splitlines()):
        if line.startswith("Sample Rate"):
            sample_rate = int(line.split(",")[1].split()[0])  # keep only the numeric part
            print(f"Sample Rate: {sample_rate}")
            break
        if il > 30:
            break
    if sample_rate is None:
        print("Could not find sample rate in the first 30 lines")
        return None, None
    # The metadata header takes up the first 14 lines; the samples follow.
    X = pd.read_csv(StringIO(text), skiprows=14, header=None)
    return X, sample_rate
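
# Example (sketch, assuming `raw` holds the bytes of an Apple Health ECG export,
# e.g. the content of a file uploaded to the API; the file name is hypothetical):
#   raw = open("ecg_export.csv", "rb").read()
#   X, fs = apple_csv_to_data(raw)   # X: raw sample columns, fs: sampling rate such as 512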


def apple_trim_join(X, sample_rate=512, ns=2):
    """Rejoin the split sample columns of an Apple ECG export and trim the noisy edges."""
    # There should be a less horrible way of doing this.
    # Each sample is split across two columns (integer part in column 0, fractional digits
    # in column 1); rebuild the float value, then drop the first and last `ns` seconds,
    # which tend to be noisy (ns=2 on a 30 s recording leaves 26 s of ECG).
    X[1] = X[1].fillna(0)
    X = X[0] + X[1] / (10 ** (X[1].astype(str).str.len() - 2))  # the "- 2" ignores the trailing ".0"
    print(f"Ignoring first and last {ns} seconds")
    X = X[ns * sample_rate:-ns * sample_rate].to_frame().T
    X = X.iloc[0].to_numpy()
    return X
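
# Worked example of the rejoin arithmetic above (illustrative values): if a row is read
# as X[0]=1 and X[1]=234.0, then str(234.0) has length 5, the divisor is 10**(5-2)=1000,
# and the rebuilt sample is 1 + 234/1000 = 1.234.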


def apple_extract_beats(X, sample_rate=512):
    """Detect R-peaks with wfdb's XQRS and cut one fixed-length (180-sample) window per beat."""
    X = apple_trim_join(X, sample_rate=sample_rate, ns=3)
    # Scale and remove NaNs (should not happen anymore)
    X = preprocessing.scale(X[~np.isnan(X)])
    # I tried to hack the detection so that it learns the peak shape instead of going with
    # the defaults, but it doesn't work. Things tried:
    # - Hardwiring n_calib_beats to a lower number (5, 3); not possible from the user side.
    # - Setting qrs_width to lower and higher values.
    # - Relaxing the correlation requirement against the Ricker wavelet.
    # Maybe explore correlation with more robust wavelets, e.g.
    #   wavelet = pywt.Wavelet('db4')
    # (lib/python3.10/site-packages/wfdb/processing/qrs.py)
    #   Conf = processing.XQRS.Conf(qrs_width=0.1)
    #   qrs = processing.XQRS(sig=X, fs=sample_rate, conf=Conf)
    # The wfdb library does not allow setting n_calib_beats.
    qrs = processing.XQRS(sig=X, fs=sample_rate)
    qrs.detect()
    peaks = qrs.qrs_inds
    print(f"Number of beats detected: {len(peaks)}")
    target_length = 180
    # Only interior peaks get a full window on both sides; the first and last peaks are skipped.
    beats = np.zeros((max(len(peaks) - 2, 0), target_length))
    for i, peak in enumerate(peaks[1:-1]):
        rr_interval = peaks[i + 1] - peaks[i]  # distance from the previous peak to this one
        window_size = int(rr_interval * 1.2)  # extend by 20% to capture the full P-QRS-T cycle
        # Define the dynamic window around the R-peak
        start = max(0, peak - window_size // 2)
        end = min(len(X), peak + window_size // 2)
        beats[i] = resample(X[start:end], target_length)
    return beats
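
# Example (sketch): end-to-end beat extraction from a parsed export. `X` and `fs` are the
# outputs of `apple_csv_to_data`; the result holds one resampled beat per row.
#   beats = apple_extract_beats(X, sample_rate=fs)   # shape (n_beats, 180)
#   beats = preproc(beats)                           # ready for the classifier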


def save_beats_csv(beats, filepath_csv):
    """Write the extracted beats, one row per beat, to `filepath_csv`."""
    pd.DataFrame(beats).to_csv(filepath_csv, index=False)


def label_decoding(values, path):
    """Map encoded class indices back to their original labels using the pickled mapping at `path`."""
    with open(path, "rb") as f:
        mapping = pickle.load(f)
    inverse_mapping = {v: k for k, v in mapping.items()}
    return [inverse_mapping[value] for value in values]
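

# Minimal usage sketch (not part of the API): wires the helpers in this module together
# for local testing. "ecg_export.csv" and "beats.csv" are hypothetical paths, not files
# shipped with the repo.
if __name__ == "__main__":
    with open("ecg_export.csv", "rb") as f:
        raw = f.read()
    X, fs = apple_csv_to_data(raw)
    if X is not None:
        beats = apple_extract_beats(X, sample_rate=fs)
        save_beats_csv(beats, "beats.csv")
        scaled = preproc(beats)  # what the inference step would receive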