Update product_return_prediction/api.py
Browse files
product_return_prediction/api.py
CHANGED
@@ -8,6 +8,7 @@ import pickle
|
|
8 |
from pathlib import Path
|
9 |
from product_return_prediction.dataset import prepare_inventory, scale_data_with_trained_scaler
|
10 |
from product_return_prediction.config import MODELS_DIR, EXTERNAL_DATA_DIR
|
|
|
11 |
|
12 |
app = FastAPI(
|
13 |
title="Product Return Prediction API",
|
@@ -48,6 +49,18 @@ def load_json(file_path: Path) -> dict:
|
|
48 |
raise HTTPException(status_code=500, detail=f"Error reading JSON file {file_path}: {e}")
|
49 |
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def filter_inventory_by_combinations(inventory: pd.DataFrame, models: list, fabrics: list, colours: list) -> pd.DataFrame:
|
52 |
"""Filter inventory based on the product combinations."""
|
53 |
filtered_inventory = pd.DataFrame()
|
@@ -113,8 +126,8 @@ async def root():
|
|
113 |
@app.post("/predict/")
|
114 |
async def predict(products: ProductRequest):
|
115 |
inventory_path: Path = EXTERNAL_DATA_DIR / "inventory.tsv"
|
116 |
-
|
117 |
-
|
118 |
|
119 |
hf_token = os.getenv("inventory_data")
|
120 |
dataset = load_dataset("molinari135/armani-inventory", token=hf_token, data_files="inventory.tsv")
|
@@ -131,6 +144,15 @@ async def predict(products: ProductRequest):
|
|
131 |
filtered_inventory, products.total_customer_purchases, products.total_customer_returns
|
132 |
)
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
model = load_model(model_path)
|
135 |
|
136 |
scaled_inventory = apply_scaling(prepared_inventory, scaler_file)
|
|
|
8 |
from pathlib import Path
|
9 |
from product_return_prediction.dataset import prepare_inventory, scale_data_with_trained_scaler
|
10 |
from product_return_prediction.config import MODELS_DIR, EXTERNAL_DATA_DIR
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
app = FastAPI(
|
14 |
title="Product Return Prediction API",
|
|
|
49 |
raise HTTPException(status_code=500, detail=f"Error reading JSON file {file_path}: {e}")
|
50 |
|
51 |
|
52 |
+
def download_file(url: str, local_path: Path, headers: dict = None):
|
53 |
+
"""Download a file from a URL and save it locally."""
|
54 |
+
try:
|
55 |
+
response = requests.get(url, headers=headers)
|
56 |
+
response.raise_for_status() # Raise an exception for HTTP errors
|
57 |
+
with open(local_path, 'wb') as f:
|
58 |
+
f.write(response.content)
|
59 |
+
return local_path
|
60 |
+
except Exception as e:
|
61 |
+
raise HTTPException(status_code=500, detail=f"Error downloading file from {url}: {e}")
|
62 |
+
|
63 |
+
|
64 |
def filter_inventory_by_combinations(inventory: pd.DataFrame, models: list, fabrics: list, colours: list) -> pd.DataFrame:
|
65 |
"""Filter inventory based on the product combinations."""
|
66 |
filtered_inventory = pd.DataFrame()
|
|
|
126 |
@app.post("/predict/")
|
127 |
async def predict(products: ProductRequest):
|
128 |
inventory_path: Path = EXTERNAL_DATA_DIR / "inventory.tsv"
|
129 |
+
model_name: str = "svm.pkl"
|
130 |
+
scaler_name: str = "scaler.pkl"
|
131 |
|
132 |
hf_token = os.getenv("inventory_data")
|
133 |
dataset = load_dataset("molinari135/armani-inventory", token=hf_token, data_files="inventory.tsv")
|
|
|
144 |
filtered_inventory, products.total_customer_purchases, products.total_customer_returns
|
145 |
)
|
146 |
|
147 |
+
models_uri = "https://huggingface.co/molinari135/se4ai-models/resolve/main/"
|
148 |
+
model_path = MODELS_DIR / model_name
|
149 |
+
scaler_path = MODELS_DIR / scaler_name
|
150 |
+
|
151 |
+
headers = f"Authorization: Bearer {hf_token}"
|
152 |
+
|
153 |
+
download_file(f{models_uri}{model_name}, model_path, headers)
|
154 |
+
download_file(f{models_uri}{scaler_name}, scaler_path, headers)
|
155 |
+
|
156 |
model = load_model(model_path)
|
157 |
|
158 |
scaled_inventory = apply_scaling(prepared_inventory, scaler_file)
|