GitHub Actions committed on
Commit d769fbc · 1 Parent(s): 68eda41

Sync API from main repo

Files changed (10)
  1. .gitattributes +0 -35
  2. Dockerfile +17 -0
  3. README.md +0 -14
  4. __init__.py +0 -0
  5. fast.py +55 -0
  6. params.py +10 -0
  7. preproc.py +21 -0
  8. requirements.txt +10 -0
  9. utils.py +40 -0
  10. wrappers.py +51 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,17 @@
+ # Use a lightweight Python image
+ FROM python:3.9-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Copy the API code and dependencies
+ COPY . /app
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Expose the API port
+ EXPOSE 8000
+
+ # Run the FastAPI server with Uvicorn (the app object is defined in fast.py)
+ CMD ["uvicorn", "fast:app", "--host", "0.0.0.0", "--port", "8000"]
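A quick way to sanity-check the container once it is built and running is to hit the root endpoint of the FastAPI app (defined in fast.py below). This is a minimal sketch; the hadt-api image tag, the port mapping, and the use of the requests library are assumptions, not part of this commit:

    # Smoke test for a container started with something like:
    #   docker build -t hadt-api . && docker run -p 8000:8000 hadt-api
    # (image tag and port mapping are illustrative assumptions)
    import requests

    resp = requests.get("http://localhost:8000/")   # root endpoint from fast.py
    resp.raise_for_status()
    print(resp.json())                              # expected: {"greeting": "Hello"}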
README.md DELETED
@@ -1,14 +0,0 @@
- ---
- title: Hadt Api
- emoji: 👀
- colorFrom: green
- colorTo: pink
- sdk: docker
- pinned: false
- license: mit
- short_description: API for Heart Arrhythmia Detection Tools
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-
- Dummy change
__init__.py ADDED
File without changes
fast.py ADDED
@@ -0,0 +1,55 @@
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from huggingface_hub import hf_hub_download
+
+ from utils import load_model_by_type, encoder_from_model
+ from preproc import label_decoding
+ import pandas as pd
+ from io import StringIO
+ from pathlib import Path
+
+ # Get the absolute path to the package directory
+ PACKAGE_ROOT = Path(__file__).parent.parent.parent
+ MODEL_DIR = PACKAGE_ROOT / "models"
+
+ app = FastAPI()
+
+ # In-memory caches for loaded models and their label encoders
+ model_cache = {}
+ encoder_cache = {}
+ HF_REPO_ID = "your-username/your-model-repo"
+
+ app.state.model = None  # Initialize as None, load on first request
+
+ @app.get("/")
+ def root():
+     return dict(greeting="Hello")
+
+ @app.post("/predict")
+ async def predict(model_name: str, filepath_csv: UploadFile = File(...)):
+     # Load model if not already loaded
+     model_path = MODEL_DIR / f"{model_name}"
+     encoder_name = encoder_from_model(model_name)
+     encoder_path = MODEL_DIR / encoder_name
+
+     # If the model exists locally, load it; otherwise download it from the HF repo
+     if model_name not in model_cache:
+         try:
+             if not model_path.exists():
+                 model_path = Path(hf_hub_download(repo_id=HF_REPO_ID, filename=f"{model_name}"))
+                 encoder_path = hf_hub_download(repo_id=HF_REPO_ID, filename=f"{encoder_name}")
+             model_cache[model_name] = load_model_by_type(model_path)
+             encoder_cache[model_name] = encoder_path
+         except Exception as e:
+             raise HTTPException(status_code=404, detail=f"Model {model_name} not found") from e
+     model = app.state.model = model_cache[model_name]
+
+     # Read the uploaded CSV file
+     file_content = await filepath_csv.read()
+     X = pd.read_csv(StringIO(file_content.decode('utf-8')), header=None).T
+     y_pred = model.predict_with_pipeline(X)
+
+     # Decode the prediction using the cached encoder path
+
+     y_pred = label_decoding(value=y_pred[0], path=encoder_cache[model_name])
+
+     return {"prediction": y_pred}
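For reference, a client-side sketch of how /predict might be called. The server URL, the model file name, and the CSV path are illustrative assumptions; the CSV is expected to hold one 180-sample segment with one value per row, since the endpoint transposes the parsed frame:

    # Hypothetical client for the /predict endpoint above.
    import requests

    url = "http://localhost:8000/predict"            # assumes the server from the Dockerfile
    params = {"model_name": "cnn_multi_model.h5"}    # any name handled by encoder_from_model
    with open("heartbeat.csv", "rb") as f:           # 180 rows, one value per row (assumed file)
        resp = requests.post(url, params=params, files={"filepath_csv": f})

    resp.raise_for_status()
    print(resp.json())                               # {"prediction": <decoded label>}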
params.py ADDED
@@ -0,0 +1,10 @@
+ import os
+
+ # GCP Project
+ GCP_PROJECT = os.environ.get("GCP_PROJECT")
+
+ # Cloud Storage
+ GOOGLE_APPLICATION_CREDENTIALS = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
+ BUCKET_NAME = os.environ.get('BUCKET_NAME')
+ BUCKET_NAME_MODELS = os.environ.get('BUCKET_NAME_MODELS')
+ LOCAL_REGISTRY_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)))
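params.py only reads environment variables, so local runs need them set before import. A small sketch with placeholder values (none of these values come from this commit):

    # Placeholder values for local development; real deployments set these
    # in the environment. The variable names match what params.py reads.
    import os

    os.environ.setdefault("GCP_PROJECT", "my-gcp-project")
    os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "/path/to/key.json")
    os.environ.setdefault("BUCKET_NAME", "my-data-bucket")
    os.environ.setdefault("BUCKET_NAME_MODELS", "my-models-bucket")

    import params
    print(params.LOCAL_REGISTRY_PATH)   # directory that contains params.py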
preproc.py ADDED
@@ -0,0 +1,21 @@
+ from tslearn.utils import to_time_series_dataset
+ from tslearn.preprocessing import TimeSeriesScalerMeanVariance
+ import pickle
+
+ def preproc_single(X):
+     # to be called in inference/api
+     in_shape = X.shape
+     if X.shape != (1, 180):
+         print('File shape is not (1, 180) but ', in_shape)
+
+     X = to_time_series_dataset(X)
+     X = X.reshape(in_shape[0], -1)
+     scaler = TimeSeriesScalerMeanVariance()
+     X = scaler.fit_transform(X)
+     return X.reshape(in_shape)
+
+ def label_decoding(value, path):
+     with open(path, "rb") as f:
+         mapping = pickle.load(f)
+     inverse_mapping = {v: k for k, v in mapping.items()}
+     return inverse_mapping[value]
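A standalone sketch of the two helpers above; the random segment and the encoder path are made up for illustration:

    import numpy as np
    from preproc import preproc_single, label_decoding

    segment = np.random.rand(1, 180)      # one 180-sample heartbeat segment
    scaled = preproc_single(segment)      # per-series mean/variance scaling
    print(scaled.shape)                   # (1, 180)

    # label_decoding expects a pickled {label: index} mapping, e.g. one of the
    # *_label_encoding.pkl files referenced in utils.py (this path is an assumption).
    print(label_decoding(value=0, path="models/cnn_multi_label_encoding.pkl"))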
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn
+ joblib
+ huggingface-hub
+ pandas==2.2.3
+ numpy==1.26.4
+ scikit-learn==1.2.2
+ tslearn
+ # pickle is provided by the Python standard library; no pip package is needed
+ tensorflow
utils.py ADDED
@@ -0,0 +1,40 @@
+ from wrappers import LSTMWrapper, XGBWrapper, CNNWrapper
+ import joblib
+ from tensorflow.keras.models import load_model
+
+
+ def load_model_by_type(model_path):
+     if model_path.suffix == '.h5':
+         if 'lstm' in str(model_path):  # covers lstm_multi and lstm_binary models
+             return LSTMWrapper(load_model(model_path))
+         elif 'cnn' in str(model_path):  # covers cnn_multi and cnn_binary models
+             return CNNWrapper(load_model(model_path))
+         else:
+             raise ValueError("Unsupported model type")
+     elif model_path.suffix == '.pkl':
+         return XGBWrapper(joblib.load(model_path))
+     else:
+         raise ValueError("Unsupported model type")
+
+ def encoder_from_model(model_name):
+     if model_name == "cnn_multi_model.h5":
+         return "cnn_multi_label_encoding.pkl"
+     elif model_name == "lstm_multi_model.h5":
+         return "lstm_multi_label_encoding.pkl"
+     elif model_name == "pca_xgboost_multi_model.pkl":
+         return "pca_xgboost_multi_label_encoding.pkl"
+     elif model_name == "cnn_binary_model.h5":
+         return "cnn_binary_label_encoding.pkl"
+     elif model_name == "lstm_binary_model.h5":
+         return "lstm_binary_label_encoding.pkl"
+     elif model_name == "pca_xgboost_binary_model.pkl":
+         return "pca_xgboost_binary_label_encoding.pkl"
+     else:
+         raise ValueError("Unsupported model name")
+
+
+ if __name__ == "__main__":
+     from pathlib import Path
+     PACKAGE_ROOT = Path(__file__).parent.parent.parent
+     MODEL_PATH = PACKAGE_ROOT / "models" / "lstm_multi_model.h5"
+     load_model_by_type(MODEL_PATH)
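Outside the API, the two helpers can be combined as below; the local models/ directory is an assumption:

    from pathlib import Path
    from utils import load_model_by_type, encoder_from_model

    model_name = "pca_xgboost_multi_model.pkl"
    model_dir = Path("models")                              # assumed local location

    wrapper = load_model_by_type(model_dir / model_name)    # returns an XGBWrapper
    encoder_path = model_dir / encoder_from_model(model_name)
    print(type(wrapper).__name__, encoder_path)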
wrappers.py ADDED
@@ -0,0 +1,51 @@
+ import numpy as np
+ from preproc import preproc_single
+
+ class BaseModelWrapper:
+     def __init__(self, model):
+         self.model = model
+
+     def preprocess(self, data):
+         """Default preprocessing (can be overridden)."""
+         return preproc_single(data)
+
+     def predict(self, data):
+         """Call the model's prediction."""
+         raise NotImplementedError("Subclasses must implement predict()")
+
+     def postprocess(self, prediction):
+         """Default postprocessing (can be overridden)."""
+         return prediction
+
+     def predict_with_pipeline(self, data):
+         """Unified prediction pipeline."""
+         processed_data = self.preprocess(data)
+         raw_prediction = self.predict(processed_data)
+         final_output = self.postprocess(raw_prediction)
+         return final_output
+
+
+ class LSTMWrapper(BaseModelWrapper):
+     def preprocess(self, data):
+         # LSTM requires additional dimension expansion
+         data = preproc_single(data)
+         return np.expand_dims(data, axis=1)  # Add time-step dimension
+
+     def predict(self, data):
+         return self.model.predict(data)
+
+     def postprocess(self, prediction):
+         # Assume the output is a probability vector; apply argmax
+         return np.argmax(prediction, axis=1).tolist()
+
+
+ class XGBWrapper(BaseModelWrapper):
+     def predict(self, data):
+         return self.model.predict(data).tolist()
+
+ class CNNWrapper(BaseModelWrapper):
+     def predict(self, data):
+         return self.model.predict(data)
+
+     def postprocess(self, prediction):
+         return np.argmax(prediction, axis=1).tolist()
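The wrapper contract (preprocess, then predict, then postprocess) makes it straightforward to plug in further model types. A hedged sketch of a hypothetical subclass, where the sigmoid-output model and the 0.5 threshold are assumptions for illustration only:

    import numpy as np
    from wrappers import BaseModelWrapper

    class SigmoidBinaryWrapper(BaseModelWrapper):
        """Wraps a model whose predict() returns per-sample probabilities in [0, 1]."""

        def predict(self, data):
            return self.model.predict(data)

        def postprocess(self, prediction):
            # Threshold probabilities into class indices 0/1
            return (np.asarray(prediction) > 0.5).astype(int).ravel().tolist()

    # Usage mirrors the built-in wrappers (my_keras_model is hypothetical):
    # wrapper = SigmoidBinaryWrapper(my_keras_model)
    # y = wrapper.predict_with_pipeline(X)    # X: array of shape (1, 180)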