from io import StringIO
import pickle

import numpy as np
import pandas as pd
from scipy.signal import resample
from sklearn import preprocessing
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
from tslearn.utils import to_time_series_dataset
from wfdb import processing

def preproc(X):
    # To be called at inference/API time: z-normalize each beat
    # (zero mean, unit variance per series).
    in_shape = X.shape
    if X.shape[1] != 180:
        print("File shape is not (n, 180) but", in_shape)
    X = to_time_series_dataset(X)   # -> (n, timesteps, 1)
    X = X.reshape(in_shape[0], -1)  # back to (n, timesteps)
    scaler = TimeSeriesScalerMeanVariance()
    X = scaler.fit_transform(X)     # tslearn returns (n, timesteps, 1)
    return X.reshape(in_shape)
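
def _demo_preproc():
    # Illustrative sketch, not part of the original module: _demo_preproc and
    # the synthetic data below are assumptions. Shows preproc() z-normalizing
    # a batch of 180-sample beats.
    rng = np.random.default_rng(0)
    beats = rng.normal(loc=5.0, scale=2.0, size=(4, 180))
    scaled = preproc(beats)
    # Each row now has mean ~0 and std ~1; the shape is unchanged.
    print(scaled.shape, scaled.mean(axis=1).round(6))
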
def apple_csv_to_data(file_content):
    # Extract the sampling rate from the header of an Apple Health ECG CSV.
    text = file_content.decode("utf-8")
    sample_rate = None
    for il, line in enumerate(text.splitlines()):
        if line.startswith("Sample Rate"):
            sample_rate = int(line.split(",")[1].split()[0])  # numeric part, e.g. of "512 Hertz"
            print(f"Sample Rate: {sample_rate}")
            break
        if il > 30:
            break
    if sample_rate is None:
        print("Could not find sample rate in first 30 lines")
        return None, None
    # The waveform itself starts after the metadata header.
    X = pd.read_csv(StringIO(text), skiprows=14, header=None)
    return X, sample_rate
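
def _demo_apple_csv_to_data():
    # Illustrative sketch, not part of the original module: "ecg_export.csv"
    # is a placeholder path for an ECG CSV exported from the Apple Health app.
    # Note that apple_csv_to_data() expects raw bytes, hence mode "rb".
    with open("ecg_export.csv", "rb") as f:
        X, sample_rate = apple_csv_to_data(f.read())
    if X is not None:
        print(X.shape, sample_rate)
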
def apple_trim_join(X, sample_rate=512, ns=2):
    # There should be a less horrible way of doing this.
    # The export splits each sample across two columns: the integer part in
    # column 0 and the decimal digits in column 1. Rejoin them, then drop the
    # first and last `ns` seconds, which tend to be noisy (a 30 s recording
    # with ns=2 leaves 26 s of ECG).
    X[1] = X[1].fillna(0)
    # len(str) - 2 ignores the trailing ".0" when counting decimal digits.
    X = X[0] + X[1] / (10 ** (X[1].astype(str).str.len() - 2))
    print(f"Ignoring first and last {ns} seconds")
    X = X[ns * sample_rate : -ns * sample_rate]
    return X.to_numpy()
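
def _demo_apple_trim_join():
    # Illustrative sketch, not part of the original module, using synthetic
    # two-column data. The rejoin maps e.g. (531, 25.0) to 531.25, since
    # len("25.0") - 2 = 2 decimal digits.
    sample_rate = 4
    n = 10 * sample_rate  # 10 "seconds" of fake samples
    X = pd.DataFrame({0: np.arange(n, dtype=float), 1: np.full(n, 25.0)})
    joined = apple_trim_join(X, sample_rate=sample_rate, ns=2)
    print(joined[:3])  # [8.25, 9.25, 10.25]: first kept samples after trimming
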
def apple_extract_beats(X, sample_rate=512):
    X = apple_trim_join(X, sample_rate=sample_rate, ns=3)
    # Scale and remove NaNs (should not happen anymore)
    X = preprocessing.scale(X[~np.isnan(X)])
    # I tried to hack the detection to make it learn peaks rather than
    # fall back to the default, but it doesn't work. I have tried:
    # - Hardwiring n_calib_beats (not possible from the user side)
    #   to a lower number (5, 3).
    # - Setting qrs_width to lower and higher values.
    # - Relaxing the correlation requirement with the Ricker wavelet.
    # Maybe explore correlation with more robust wavelets:
    #   wavelet = pywt.Wavelet('db4')
    # (lib/python3.10/site-packages/wfdb/processing/qrs.py)
    #   conf = processing.XQRS.Conf(qrs_width=0.1)
    #   qrs = processing.XQRS(sig=X, fs=sample_rate, conf=conf)
    # The wfdb library doesn't allow setting n_calib_beats.
    qrs = processing.XQRS(sig=X, fs=sample_rate)
    qrs.detect()
    peaks = qrs.qrs_inds
    print("Number of beats detected:", len(peaks))
    target_length = 180
    # Skip the first and last peaks, whose windows may be truncated.
    beats = np.zeros((max(len(peaks) - 2, 0), target_length))
    for i, peak in enumerate(peaks[1:-1]):
        rr_interval = peaks[i + 1] - peaks[i]  # distance from the previous peak
        window_size = int(rr_interval * 1.2)  # extend by 20% to capture the full P-QRS-T cycle
        # Define the dynamic window around the R-peak
        start = max(0, peak - window_size // 2)
        end = min(len(X), peak + window_size // 2)
        beats[i] = resample(X[start:end], target_length)
    return beats
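
def _demo_pipeline():
    # Illustrative end-to-end sketch, not part of the original module:
    # "ecg_export.csv" is a placeholder path. CSV -> beat matrix ->
    # normalized model input of shape (n_beats, 180).
    with open("ecg_export.csv", "rb") as f:
        X, sample_rate = apple_csv_to_data(f.read())
    if X is None:
        return
    beats = apple_extract_beats(X, sample_rate=sample_rate)
    print(preproc(beats).shape)
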
def save_beats_csv(beats, filepath_csv):
    pd.DataFrame(beats).to_csv(filepath_csv, index=False)
def label_decoding(values, path):
    # Invert the pickled {label: code} mapping and decode a list of codes.
    with open(path, "rb") as f:
        mapping = pickle.load(f)
    inverse_mapping = {v: k for k, v in mapping.items()}
    return [inverse_mapping[value] for value in values]
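
def _demo_label_decoding():
    # Illustrative sketch, not part of the original module: round-trips a
    # hypothetical {label: code} mapping through label_decoding().
    mapping = {"N": 0, "S": 1, "V": 2}
    with open("label_mapping.pkl", "wb") as f:
        pickle.dump(mapping, f)
    print(label_decoding([0, 2, 1], "label_mapping.pkl"))  # ['N', 'V', 'S']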