Spaces:
No application file
No application file
import os | |
import warnings | |
import logging | |
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
# Suppress TensorFlow warnings | |
warnings.filterwarnings('ignore', category=UserWarning) | |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging | |
logging.getLogger('tensorflow').setLevel(logging.ERROR) | |
from pydantic import BaseModel | |
import pickle | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from sklearn.metrics.pairwise import cosine_similarity | |
from backend.models.schemas import RecommendationRequest, ProductRecommendation, RecommendationResponse | |
# ASGI application object. The title, description and version given here are
# the metadata FastAPI attaches to the application.
app = FastAPI(
    version="1.0.0",
    title="Product Recommendation API",
    description="API for getting product recommendations based on user queries",
)
# Cross-origin policy: every origin, method and header is accepted, and
# credentialed requests are enabled.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard values with allow_credentials=True. Confirm whether credentials
# are actually needed; if they are, list the allowed origins explicitly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Locate the pickled product embeddings in the "models" folder that sits
# next to this file.
models_dir = os.path.join(os.path.dirname(__file__), "models")
embeddings_path = os.path.join(models_dir, "product_embeddings.pkl")

# Create the folder when it is missing; the embeddings file itself must
# still be provided, otherwise start-up fails below.
os.makedirs(models_dir, exist_ok=True)

try:
    # Fail fast with a clear message when the embeddings were never generated.
    if not os.path.exists(embeddings_path):
        raise FileNotFoundError("Product embeddings file not found")

    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only ship embeddings files produced by a trusted pipeline.
    with open(embeddings_path, "rb") as embeddings_file:
        data = pickle.load(embeddings_file)

    # Sentence-embedding model used to encode incoming queries.
    model = SentenceTransformer("all-MiniLM-L6-v2")
except Exception as exc:
    # Report the failure, then re-raise so the import (and the server) aborts.
    print(f"Error initializing server: {exc}")
    raise
# NOTE(review): no route decorator is visible above this handler in this
# view of the file - confirm that it is registered on `app`.
def read_root():
    """Return a small status payload describing the running service.

    Returns:
        dict: the API name under ``"message"``, the fixed status string
        ``"active"`` under ``"status"``, and under ``"total_products"`` the
        length of ``data["product_ids"]`` (0 when that key is absent).
    """
    product_count = 0
    if "product_ids" in data:
        product_count = len(data["product_ids"])

    return {
        "message": "Product Recommendation API",
        "status": "active",
        "total_products": product_count,
    }
def _parse_categories(raw):
    """Decode one stored ``categories`` value.

    Non-string values are returned unchanged. Strings are parsed as Python
    literals (for example ``"['a', 'b']"``).
    """
    if not isinstance(raw, str):
        return raw

    # Function-scope import keeps this block self-contained.
    import ast

    # literal_eval replaces the previous eval(): it accepts only literal
    # syntax, so a crafted string in the data file cannot run code.
    return ast.literal_eval(raw)


# NOTE(review): no route decorator is visible above this handler in this
# view of the file - confirm that it is registered on `app`.
def recommend_products(request: RecommendationRequest):
    """Get product recommendations based on user search.

    Encodes ``request.query`` with the sentence-transformer ``model``, scores
    it against ``data["embeddings"]`` by cosine similarity and returns the
    best-scoring products, highest score first.

    Args:
        request: Supplies the search text (``query``) and the number of
            results wanted (``top_n``).

    Returns:
        RecommendationResponse: ``recommendations`` holds at most ``top_n``
        entries in descending score order. It is empty when ``top_n`` is
        zero or negative.

    Raises:
        HTTPException: status 500 carrying the error text, for any failure
            while encoding, scoring or assembling the response.
    """
    try:
        query_embedding = model.encode(request.query).reshape(1, -1)

        # Compute similarity scores (one per stored product).
        scores = cosine_similarity(query_embedding, data["embeddings"])[0]

        # Clamp the requested count to [0, number of products]. The previous
        # slice `[-top_n:]` returned the whole catalogue for top_n == 0
        # (because -0 == 0) and a meaningless subset for negative values.
        limit = max(0, min(request.top_n, len(scores)))

        # argsort is ascending, so reverse it and keep the first `limit`.
        top_indices = np.argsort(scores)[::-1][:limit]

        recommendations = [
            ProductRecommendation(
                product_id=data["product_ids"][i],
                product_name=data["product_names"][i],
                description=data["descriptions"][i],
                brand=data["brands"][i],
                price=float(data["prices"][i]),
                categories=_parse_categories(data["categories"][i]),
                score=float(scores[i]),
            )
            for i in top_indices
        ]
        return RecommendationResponse(recommendations=recommendations)
    except Exception as e:
        # Top-level boundary: surface every failure as HTTP 500 and keep the
        # original exception chained for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000, with
    # auto-reload and access logging enabled.
    try:
        uvicorn.run(
            "backend.server:app",
            host="0.0.0.0",
            port=8000,
            reload=True,
            access_log=True,
        )
    except Exception as e:
        print(f"Error starting server: {e}")
        # Exit with a non-zero status so a supervisor or container runtime
        # sees the failure; previously the script printed the error and
        # then exited with status 0.
        raise SystemExit(1) from e