# File size: 3,774 Bytes
# 500c1ba
from huggingface_hub import login
from fastapi import FastAPI, Depends, HTTPException
import logging
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from services.qdrant_searcher import QdrantSearcher
from services.openai_service import generate_rag_response
from utils.auth import token_required
from dotenv import load_dotenv
import os
# Application bootstrap. Everything below runs once, in order, when the module
# is imported: environment loading, app creation, cache directory setup,
# logging, Hugging Face login, and construction of the shared encoder/searcher.

load_dotenv()  # Load environment variables from .env file

app = FastAPI()

# NOTE(review): HF_HOME is assigned here, after huggingface_hub and
# sentence_transformers have already been imported above. Confirm those
# libraries still pick the value up at this point.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Ensure the cache directory exists. exist_ok=True replaces the former
# "check, then create" sequence, which could fail if another process created
# the directory between the two steps.
cache_dir = os.environ["HF_HOME"]
os.makedirs(cache_dir, exist_ok=True)

# Setup logging
logging.basicConfig(level=logging.INFO)

# Load Hugging Face token from environment variable; refuse to start without it.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")
login(token=huggingface_token, add_to_git_credential=True)

# Initialize the Qdrant searcher
qdrant_url = os.getenv('QDRANT_URL')
access_token = os.getenv('QDRANT_ACCESS_TOKEN')
# NOTE(review): trust_remote_code=True permits code shipped with the model
# repository to run; confirm it is actually needed for this encoder.
encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2', trust_remote_code=True)  # Replace with your actual encoder
searcher = QdrantSearcher(encoder, qdrant_url, access_token)
# Request body models
# JSON body accepted by the /api/search-documents endpoint.
# (Kept as comments rather than a docstring so the generated schema is unchanged.)
class SearchDocumentsRequest(BaseModel):
    # Query text handed to the searcher.
    query: str
    # Hit limit handed to the searcher; 3 when the client omits it.
    limit: int = 3
# JSON body accepted by the /api/generate-rag-response endpoint.
# (Kept as comments rather than a docstring so the generated schema is unchanged.)
class GenerateRAGRequest(BaseModel):
    # Query text used both for the document search and for response generation.
    search_query: str
@app.post("/api/search-documents")
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Search the ``my_embeddings`` collection on behalf of the caller.

    Args:
        body: Request payload carrying the query text and the hit limit.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        The hits produced by ``searcher.search_documents``.

    Raises:
        HTTPException: 401 when ``customer_id`` or ``user_id`` is missing;
            500 when the searcher reports an error or an unexpected
            exception occurs.
    """
    customer_id, user_id = credentials

    # Check if customer_id or user_id is missing
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to search documents")

    try:
        collection_name = "my_embeddings"
        hits, error = searcher.search_documents(collection_name, body.query, user_id, body.limit)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        return hits
    except HTTPException:
        # The HTTP error raised above is intentional. Without this clause the
        # generic handler below would catch it, log it as "unexpected", and
        # replace its detail with the exception's string form.
        raise
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/generate-rag-response")
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Search ``my_embeddings`` and generate a RAG response from the hits.

    Args:
        body: Request payload carrying the search query.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        A dict of the form ``{"response": <generated response>}``.

    Raises:
        HTTPException: 401 when ``customer_id`` or ``user_id`` is missing;
            500 when the search or the generation step reports an error, or
            an unexpected exception occurs.
    """
    customer_id, user_id = credentials

    # Check if customer_id or user_id is missing
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to generate RAG response")

    try:
        collection_name = "my_embeddings"
        # No limit is passed here, so the searcher's own default applies.
        hits, error = searcher.search_documents(collection_name, body.search_query, user_id)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        response, error = generate_rag_response(hits, body.search_query)

        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)

        return {"response": response}
    except HTTPException:
        # The HTTP errors raised above are intentional. Without this clause the
        # generic handler below would catch them, log them as "unexpected",
        # and replace their detail with the exception's string form.
        raise
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on 0.0.0.0:8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)