# e-commerce/backend/server.py
# VincentA2K's picture
# Product Recommendation REST API
# 480e694
import os
import warnings
import logging
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# Quiet TensorFlow-related noise before the remaining imports run. These
# statements sit ahead of the sentence_transformers import on purpose —
# presumably that package is what pulls TensorFlow in (TODO confirm).
# NOTE: the filter below has no module restriction, so it hides every
# UserWarning raised in this process, not only TensorFlow's.
warnings.filterwarnings('ignore', category=UserWarning)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging
# Only ERROR and above are emitted by the 'tensorflow' logger.
logging.getLogger('tensorflow').setLevel(logging.ERROR)
from pydantic import BaseModel
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from backend.models.schemas import RecommendationRequest, ProductRecommendation, RecommendationResponse
# Application object; title, description and version are the metadata
# passed to FastAPI for this service.
app = FastAPI(
    version="1.0.0",
    title="Product Recommendation API",
    description="API for getting product recommendations based on user queries",
)

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy — confirm it is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# The pickled embeddings live in the "models" folder next to this module.
models_dir = os.path.join(os.path.dirname(__file__), "models")
embeddings_path = os.path.join(models_dir, "product_embeddings.pkl")

# Create the models folder when it is missing (no-op when it already exists).
os.makedirs(models_dir, exist_ok=True)

# Everything the endpoints need is loaded once, at import time. Any failure is
# printed and then re-raised, so the module does not import half-initialised.
try:
    # Guard clause: fail fast when the embeddings file is absent.
    if not os.path.exists(embeddings_path):
        raise FileNotFoundError("Product embeddings file not found")

    # NOTE(review): unpickling can execute code stored in the file; only load
    # embeddings files produced by a trusted pipeline.
    with open(embeddings_path, "rb") as f:
        data = pickle.load(f)

    # Sentence encoder used to embed incoming search queries.
    model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as exc:
    print(f"Error initializing server: {exc}")
    raise
@app.get("/")
def read_root():
    # Status endpoint: reports the service name, a fixed "active" status and
    # how many product ids the loaded embeddings data holds.
    # (Comments rather than a docstring, so the published endpoint description
    # stays exactly as it was.)
    if "product_ids" in data:
        product_count = len(data["product_ids"])
    else:
        # No "product_ids" entry in the loaded data: report zero products.
        product_count = 0

    return {
        "message": "Product Recommendation API",
        "status": "active",
        "total_products": product_count,
    }
def _parse_categories(raw):
    """Return *raw* unchanged unless it is a string, in which case decode it.

    Strings are parsed with ``ast.literal_eval``, which accepts Python
    literals only (lists, tuples, strings, numbers, ...). The previous
    ``eval`` call would have executed any expression stored in the
    embeddings data.
    """
    if not isinstance(raw, str):
        return raw
    import ast  # local import so this block does not depend on file-level edits
    return ast.literal_eval(raw)


def _top_indices(scores, top_n):
    """Return the indices of the *top_n* highest *scores*, best first.

    A non-positive *top_n* yields no indices. Slicing with ``[-top_n:]``
    would return the whole array for ``top_n == 0`` (because ``-0 == 0``)
    and an arbitrary tail for negative values.
    """
    if top_n <= 0:
        return np.empty(0, dtype=np.intp)
    # Sort ascending, flip to descending, keep the leading top_n entries.
    return np.argsort(scores)[::-1][:top_n]


@app.post("/recommend", response_model=RecommendationResponse)
def recommend_products(request: RecommendationRequest):
    """Get product recommendations based on user search.

    The query is embedded, compared with every stored product embedding by
    cosine similarity, and the ``top_n`` best matches are returned with the
    highest score first. A ``top_n`` of zero or less returns an empty list.
    Any failure is reported as an HTTP 500 whose detail is the error text.
    """
    try:
        # The encoder output is reshaped into a single-row matrix for
        # cosine_similarity.
        query_embedding = model.encode(request.query).reshape(1, -1)

        # One row of scores: the query against every stored embedding.
        scores = cosine_similarity(query_embedding, data["embeddings"])[0]

        recommendations = [
            ProductRecommendation(
                product_id=data["product_ids"][i],
                product_name=data["product_names"][i],
                description=data["descriptions"][i],
                brand=data["brands"][i],
                price=float(data["prices"][i]),
                categories=_parse_categories(data["categories"][i]),
                score=float(scores[i]),
            )
            for i in _top_indices(scores, request.top_n)
        ]
        return RecommendationResponse(recommendations=recommendations)
    except Exception as e:
        # Same response as before; chaining keeps the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000, with auto-reload and access logging switched on.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "access_log": True,
    }
    try:
        # The app is passed as an import string rather than as the object.
        uvicorn.run("backend.server:app", **server_options)
    except Exception as exc:
        # Start-up failures are reported on stdout and not re-raised.
        print(f"Error starting server: {exc}")