GitHub Actions commited on
Commit
947b4e0
·
1 Parent(s): 03cf173

Sync API from main repo

Browse files
Files changed (2) hide show
  1. fast.py +17 -1
  2. preproc.py +15 -13
fast.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from huggingface_hub import hf_hub_download
3
 
4
  from utils import load_model_by_type, encoder_from_model
5
- from preproc import label_decoding
6
  import pandas as pd
7
  from io import StringIO
8
  from pathlib import Path
@@ -74,6 +74,22 @@ async def predict(model_name: str, filepath_csv: UploadFile = File(...)):
74
 
75
  return {"prediction": y_pred}
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # @app.post("/predict_multibeats")
79
  # async def predict_multibeats(model_name: str, filepath_csv: UploadFile = File(...)):
 
2
  from huggingface_hub import hf_hub_download
3
 
4
  from utils import load_model_by_type, encoder_from_model
5
+ from preproc import label_decoding, apple_csv_to_data, apple_extract_beats
6
  import pandas as pd
7
  from io import StringIO
8
  from pathlib import Path
 
74
 
75
  return {"prediction": y_pred}
76
 
77
+ @app.post("/predict_multibeats")
78
+ async def predict_multibeats(model_name: str, filepath_csv: UploadFile = File(...)):
79
+ model = app.state.model = model_loader(model_name)
80
+
81
+ # Read the uploaded CSV file
82
+ file_content = await filepath_csv.read()
83
+ # X = pd.read_csv(StringIO(file_content.decode('utf-8')))
84
+ X, sample_rate = apple_csv_to_data(file_content)
85
+ beats = apple_extract_beats(X, sample_rate)
86
+ y_pred = model.predict_with_pipeline(beats)
87
+
88
+ # Decode prediction using absolute path
89
+
90
+ y_pred = label_decoding(values=y_pred, path=encoder_cache[model_name])
91
+
92
+ return {"prediction": y_pred}
93
 
94
  # @app.post("/predict_multibeats")
95
  # async def predict_multibeats(model_name: str, filepath_csv: UploadFile = File(...)):
preproc.py CHANGED
@@ -1,9 +1,11 @@
1
  from tslearn.utils import to_time_series_dataset
2
  from tslearn.preprocessing import TimeSeriesScalerMeanVariance
3
  import pickle
4
- from wfdb import rdrecord, rdann, processing
5
  from sklearn import preprocessing
6
  from scipy.signal import resample
 
 
7
 
8
  import numpy as np
9
  import pandas as pd
@@ -20,19 +22,19 @@ def preproc(X):
20
  X = scaler.fit_transform(X)
21
  return X.reshape(in_shape)
22
 
23
- def apple_csv_to_data(filepath_csv):
24
  # extract sampling rate
25
- with open(filepath_csv, 'r') as file:
26
- for il,line in enumerate(file):
27
- if line.startswith("Sample Rate"):
28
- # Extract the sample rate
29
- sample_rate = int(line.split(",")[1].split()[0]) # Split and get the numerical part
30
- print(f"Sample Rate: {sample_rate}")
31
- break
32
- if il > 30:
33
- print("Could not find sample rate in first 30 lines")
34
- return None, None
35
- X = pd.read_csv(filepath_csv, skiprows=14, header=None)
36
  return X, sample_rate
37
 
38
  def apple_trim_join(X, sample_rate=512, ns=2):
 
1
  from tslearn.utils import to_time_series_dataset
2
  from tslearn.preprocessing import TimeSeriesScalerMeanVariance
3
  import pickle
4
+ from wfdb import processing
5
  from sklearn import preprocessing
6
  from scipy.signal import resample
7
+ from io import StringIO
8
+
9
 
10
  import numpy as np
11
  import pandas as pd
 
22
  X = scaler.fit_transform(X)
23
  return X.reshape(in_shape)
24
 
25
+ def apple_csv_to_data(file_content):
26
  # extract sampling rate
27
+
28
+ for il,line in enumerate(file_content.decode('utf-8').splitlines()):
29
+ if line.startswith("Sample Rate"):
30
+ # Extract the sample rate
31
+ sample_rate = int(line.split(",")[1].split()[0]) # Split and get the numerical part
32
+ print(f"Sample Rate: {sample_rate}")
33
+ break
34
+ if il > 30:
35
+ print("Could not find sample rate in first 30 lines")
36
+ return None, None
37
+ X = pd.read_csv(StringIO(file_content.decode('utf-8')), skiprows=14, header=None)
38
  return X, sample_rate
39
 
40
  def apple_trim_join(X, sample_rate=512, ns=2):