GitHub Actions committed · Commit 6c0d821 · 1 parent: c0713b1

Sync API from main repo

Files changed (4):
  1. fast.py +18 -7
  2. preproc.py +77 -5
  3. requirements.txt +1 -0
  4. wrappers.py +3 -3
fast.py CHANGED
@@ -37,8 +37,7 @@ app.state.model = None # Initialize as None, load on first request
 def root():
     return dict(greeting="Hello")
 
-@app.post("/predict")
-async def predict(model_name: str, filepath_csv: UploadFile = File(...)):
+def model_loader(model_name):
     # Load model if not already loaded
     model_path = MODEL_DIR / f"{model_name}"
     encoder_name = encoder_from_model(model_name)
@@ -46,20 +45,23 @@ async def predict(model_name: str, filepath_csv: UploadFile = File(...)):
 
     # if model in model_path, load it, otherwise download it from HF
     if model_name not in model_cache:
-        # print("model_name", model_name)
-        # print("model_path", model_path)
         try:
             if not model_path.exists():
                 # Convert downloaded paths to Path objects
                 model_path = Path(hf_hub_download(repo_id=HF_REPO_ID, filename=f"{model_name}", cache_dir=CACHE_DIR))
                 encoder_path = Path(hf_hub_download(repo_id=HF_REPO_ID, filename=f"{encoder_name}", cache_dir=CACHE_DIR))
-                # print("model_path", model_path)
             model_cache[model_name] = load_model_by_type(model_path) # Ensure string path for loading
             encoder_cache[model_name] = encoder_path
         except Exception as e:
             print(f"Error loading model: {str(e)}") # Add debug print
             raise HTTPException(status_code=404, detail=f"Model {model_name} not found: {str(e)}")
-    model = app.state.model = model_cache[model_name]
+    return model_cache[model_name]
+
+
+@app.post("/predict")
+async def predict(model_name: str, filepath_csv: UploadFile = File(...)):
+
+    model = app.state.model = model_loader(model_name)
 
     # Read the uploaded CSV file
     file_content = await filepath_csv.read()
@@ -68,6 +70,15 @@ async def predict(model_name: str, filepath_csv: UploadFile = File(...)):
 
     # Decode prediction using absolute path
 
-    y_pred = label_decoding(value=y_pred[0], path=encoder_cache[model_name])
+    y_pred = label_decoding(values=y_pred, path=encoder_cache[model_name])
 
     return {"prediction": y_pred}
+
+
+# @app.post("/predict_multibeats")
+# async def predict_multibeats(model_name: str, filepath_csv: UploadFile = File(...)):
+#     # Read the uploaded CSV file
+#     file_content = await filepath_csv.read()
+#     X = pd.read_csv(StringIO(file_content.decode('utf-8')))
+#     y_pred = model.predict_with_pipeline(X)
+#     return {"prediction": y_pred}
preproc.py CHANGED
@@ -1,12 +1,18 @@
 from tslearn.utils import to_time_series_dataset
 from tslearn.preprocessing import TimeSeriesScalerMeanVariance
 import pickle
+from wfdb import rdrecord, rdann, processing
+from sklearn import preprocessing
+from scipy.signal import resample
 
-def preproc_single(X):
+import numpy as np
+import pandas as pd
+
+def preproc(X):
     # to be called in inference/api
     in_shape = X.shape
-    if X.shape != (1, 180):
-        print('File shape is not (1, 180) but ', in_shape)
+    if X.shape[1] != 180:
+        print('File shape is not (n, 180) but ', in_shape)
 
     X = to_time_series_dataset(X)
     X = X.reshape(in_shape[0], -1)
@@ -14,8 +20,74 @@ def preproc_single(X):
     X = scaler.fit_transform(X)
     return X.reshape(in_shape)
 
-def label_decoding(value, path):
+def apple_csv_to_data(filepath_csv):
+    # extract sampling rate
+    with open(filepath_csv, 'r') as file:
+        for il, line in enumerate(file):
+            if line.startswith("Sample Rate"):
+                # Extract the sample rate
+                sample_rate = int(line.split(",")[1].split()[0])  # Split and get the numerical part
+                print(f"Sample Rate: {sample_rate}")
+                break
+            if il > 30:
+                print("Could not find sample rate in first 30 lines")
+                return None, None
+    X = pd.read_csv(filepath_csv, skiprows=14, header=None)
+    return X, sample_rate
+
+def apple_trim_join(X, sample_rate=512, ns=2):
+    # There should be a less horrible way of doing this
+    # Ignore the first and last two seconds, which tend to be noisy --> 26 seconds of ECG
+    X[1] = X[1].fillna(0)
+    X = X[0] + X[1] / (10 ** (X[1].astype(str).str.len() - 2))  # Ignoring the trailing ".0"
+    print(f"Ignoring first and last {ns} seconds")
+    X = X[ns*sample_rate:-ns*sample_rate].to_frame().T
+    X = X.iloc[0].to_numpy()
+    return X
+
+def apple_extract_beats(X, sample_rate=512):
+    X = apple_trim_join(X, sample_rate=sample_rate, ns=3)
+    # Scale and remove NaNs (should not happen anymore)
+    X = preprocessing.scale(X[~np.isnan(X)])
+
+    # I tried to tweak the detection to make it learn peaks rather than
+    # go with the defaults, but it doesn't work.
+    # I have tried:
+    # - Hardwiring n_calib_beats (not possible from the user side)
+    #   to a lower number (5, 3).
+    # - Setting qrs_width to lower and higher values
+    # - Relaxing the correlation requirement to the Ricker wavelet
+    # Maybe explore correlation with more robust wavelets
+    # wavelet = pywt.Wavelet('db4')
+    # (lib/python3.10/site-packages/wfdb/processing/qrs.py)
+
+    # Conf = processing.XQRS.Conf(qrs_width=0.1)
+    # qrs = processing.XQRS(sig=X, fs=sample_rate, conf=Conf)
+    # wfdb library doesn't allow setting n_calib_beats
+
+    qrs = processing.XQRS(sig=X, fs=sample_rate)
+    qrs.detect()
+    peaks = qrs.qrs_inds
+    print("Number of beats detected: ", len(peaks))
+    target_length = 180
+    beats = np.zeros((len(peaks), target_length))
+
+    for i, peak in enumerate(peaks[1:-1]):
+        rr_interval = peaks[i + 1] - peaks[i]  # Distance to the next peak
+        window_size = int(rr_interval * 1.2)  # Extend by 20% to capture full P-QRS-T cycle
+        # Define the dynamic window around the R-peak
+        start = max(0, peak - window_size // 2)
+        end = min(len(X), peak + window_size // 2)
+        beat = resample(X[start:end], target_length)
+        beats[i] = beat
+    return beats
+
+def save_beats_csv(beats, filepath_csv):
+    pd.DataFrame(beats).to_csv(filepath_csv, index=False)
+
+def label_decoding(values, path):
     with open(path, "rb") as f:
         mapping = pickle.load(f)
         inverse_mapping = {v: k for k, v in mapping.items()}
-    return inverse_mapping[value]
+    # return inverse_mapping[values]
+    return [inverse_mapping[value] for value in values]
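
Taken together, the new helpers form a small offline pipeline from an Apple Watch ECG export to fixed-length beats ready for the API. A rough usage sketch, chaining the functions added above; the input and output paths are placeholders.

from preproc import apple_csv_to_data, apple_extract_beats, save_beats_csv

csv_path = "apple_ecg_export.csv"   # placeholder path to an Apple Health ECG export

X, sample_rate = apple_csv_to_data(csv_path)                  # raw samples + sample rate from the CSV header
if X is not None:
    beats = apple_extract_beats(X, sample_rate=sample_rate)   # (n_beats, 180) array of resampled beats
    save_beats_csv(beats, "beats.csv")                        # one beat per row, suitable for /predict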
requirements.txt CHANGED
@@ -5,6 +5,7 @@ huggingface-hub
 pandas==2.2.3
 numpy==1.26.4
 scikit-learn==1.2.2
+scipy
 tslearn
 tensorflow
 python-multipart
wrappers.py CHANGED
@@ -1,5 +1,5 @@
 import numpy as np
-from preproc import preproc_single
+from preproc import preproc
 
 class BaseModelWrapper:
     def __init__(self, model):
@@ -7,7 +7,7 @@ class BaseModelWrapper:
 
     def preprocess(self, data):
         """Default preprocessing (can be overridden)."""
-        return preproc_single(data)
+        return preproc(data)
 
     def predict(self, data):
         """Call the model's prediction."""
@@ -28,7 +28,7 @@ class BaseModelWrapper:
 class LSTMWrapper(BaseModelWrapper):
     def preprocess(self, data):
         # LSTM requires additional dimension expansion
-        data = preproc_single(data)
+        data = preproc(data)
         return np.expand_dims(data, axis=1) # Add time-step dimension
 
     def predict(self, data):
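
The only change here is the rename of the preprocessing helper; the wrapper interface is untouched. A rough sketch of the preprocessing hook, assuming the constructor (not shown in this hunk) simply stores the model object:

import numpy as np
from wrappers import LSTMWrapper

wrapper = LSTMWrapper(model=None)    # None stands in for a real loaded model; only preprocess() is exercised

beat = np.random.rand(1, 180)        # one 180-sample beat (dummy data)
prepped = wrapper.preprocess(beat)   # preproc() scaling plus the added time-step axis
print(prepped.shape)                 # expected: (1, 1, 180)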