molinari135 commited on
Commit
f8dac16
·
verified ·
1 Parent(s): 36ee4da

Update product_return_prediction/api.py

Browse files
Files changed (1) hide show
  1. product_return_prediction/api.py +24 -2
product_return_prediction/api.py CHANGED
@@ -8,6 +8,7 @@ import pickle
8
  from pathlib import Path
9
  from product_return_prediction.dataset import prepare_inventory, scale_data_with_trained_scaler
10
  from product_return_prediction.config import MODELS_DIR, EXTERNAL_DATA_DIR
 
11
 
12
  app = FastAPI(
13
  title="Product Return Prediction API",
@@ -48,6 +49,18 @@ def load_json(file_path: Path) -> dict:
48
  raise HTTPException(status_code=500, detail=f"Error reading JSON file {file_path}: {e}")
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def filter_inventory_by_combinations(inventory: pd.DataFrame, models: list, fabrics: list, colours: list) -> pd.DataFrame:
52
  """Filter inventory based on the product combinations."""
53
  filtered_inventory = pd.DataFrame()
@@ -113,8 +126,8 @@ async def root():
113
  @app.post("/predict/")
114
  async def predict(products: ProductRequest):
115
  inventory_path: Path = EXTERNAL_DATA_DIR / "inventory.tsv"
116
- model_path: Path = MODELS_DIR / "svm.pkl"
117
- scaler_file: Path = MODELS_DIR / "scaler.pkl"
118
 
119
  hf_token = os.getenv("inventory_data")
120
  dataset = load_dataset("molinari135/armani-inventory", token=hf_token, data_files="inventory.tsv")
@@ -131,6 +144,15 @@ async def predict(products: ProductRequest):
131
  filtered_inventory, products.total_customer_purchases, products.total_customer_returns
132
  )
133
 
 
 
 
 
 
 
 
 
 
134
  model = load_model(model_path)
135
 
136
  scaled_inventory = apply_scaling(prepared_inventory, scaler_file)
 
8
  from pathlib import Path
9
  from product_return_prediction.dataset import prepare_inventory, scale_data_with_trained_scaler
10
  from product_return_prediction.config import MODELS_DIR, EXTERNAL_DATA_DIR
11
+ from huggingface_hub import hf_hub_download
12
 
13
  app = FastAPI(
14
  title="Product Return Prediction API",
 
49
  raise HTTPException(status_code=500, detail=f"Error reading JSON file {file_path}: {e}")
50
 
51
 
52
+ def download_file(url: str, local_path: Path, headers: dict = None):
53
+ """Download a file from a URL and save it locally."""
54
+ try:
55
+ response = requests.get(url, headers=headers)
56
+ response.raise_for_status() # Raise an exception for HTTP errors
57
+ with open(local_path, 'wb') as f:
58
+ f.write(response.content)
59
+ return local_path
60
+ except Exception as e:
61
+ raise HTTPException(status_code=500, detail=f"Error downloading file from {url}: {e}")
62
+
63
+
64
  def filter_inventory_by_combinations(inventory: pd.DataFrame, models: list, fabrics: list, colours: list) -> pd.DataFrame:
65
  """Filter inventory based on the product combinations."""
66
  filtered_inventory = pd.DataFrame()
 
126
  @app.post("/predict/")
127
  async def predict(products: ProductRequest):
128
  inventory_path: Path = EXTERNAL_DATA_DIR / "inventory.tsv"
129
+ model_name: str = "svm.pkl"
130
+ scaler_name: str = "scaler.pkl"
131
 
132
  hf_token = os.getenv("inventory_data")
133
  dataset = load_dataset("molinari135/armani-inventory", token=hf_token, data_files="inventory.tsv")
 
144
  filtered_inventory, products.total_customer_purchases, products.total_customer_returns
145
  )
146
 
147
+ models_uri = "https://huggingface.co/molinari135/se4ai-models/resolve/main/"
148
+ model_path = MODELS_DIR / model_name
149
+ scaler_path = MODELS_DIR / scaler_name
150
+
151
+ headers = f"Authorization: Bearer {hf_token}"
152
+
153
+ download_file(f{models_uri}{model_name}, model_path, headers)
154
+ download_file(f{models_uri}{scaler_name}, scaler_path, headers)
155
+
156
  model = load_model(model_path)
157
 
158
  scaled_inventory = apply_scaling(prepared_inventory, scaler_file)