|
from huggingface_hub import login |
|
from fastapi import FastAPI, Depends, HTTPException |
|
import logging |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoModel |
|
from services.qdrant_searcher import QdrantSearcher |
|
from services.openai_service import generate_rag_response |
|
from utils.auth import token_required |
|
from dotenv import load_dotenv |
|
import os |
|
import torch |
|
|
|
|
|
# Load variables from a local .env file into the process environment so the
# os.getenv() lookups later in this module can see them.
load_dotenv()

# Application object that the @app.post route decorators below register on.
app = FastAPI()
|
|
|
|
|
# Point the Hugging Face cache at a temp location.
# NOTE(review): huggingface_hub and transformers are imported above this line
# and appear to resolve their cache paths at import time, so this assignment
# may come too late to affect them -- confirm, or export HF_HOME before launch.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

hf_home_dir = os.environ["HF_HOME"]
# exist_ok=True creates the directory when needed without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(hf_home_dir, exist_ok=True)

# Root logger at INFO so the startup and request messages below are emitted.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
# Authenticate against the Hugging Face Hub while the module is imported;
# the service refuses to start without a token.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")

try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    # This code runs at import time, outside any request, so an HTTPException
    # could never be turned into an HTTP response here. Raise a plain startup
    # error instead and keep the original cause chained for the traceback.
    logging.exception("Failed to log into Hugging Face Hub: %s", e)
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
else:
    logging.info("Successfully logged into Hugging Face Hub.")
|
|
|
|
|
# Qdrant connection settings, taken from the environment.
qdrant_url = os.getenv('QDRANT_URL')
access_token = os.getenv('QDRANT_ACCESS_TOKEN')

# Both values are mandatory; abort the import if either is missing or empty.
if not (qdrant_url and access_token):
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
|
|
|
|
|
# Load the embedding model once at import time and create the shared searcher.
try:
    # NOTE(review): cache_folder is computed but never used -- neither
    # from_pretrained() call below receives it. If the intent was to cache
    # model files here, pass cache_dir=cache_folder to both calls; confirm
    # against the deployment before changing where downloads land.
    cache_folder = os.path.join(hf_home_dir, "transformers_cache")

    # trust_remote_code=True lets the model repository supply its own Python
    # code, which is executed locally when the model is loaded.
    tokenizer = AutoTokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
    model = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)

    logging.info("Successfully loaded the model and tokenizer with transformers.")

    # Module-level name shared by the request handlers. No `global`
    # statement is needed: this assignment already runs at module scope.
    searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
except Exception as e:
    # Import-time failure: there is no request to answer, so raise a plain
    # startup error (with the cause chained) rather than an HTTPException.
    logging.exception("Failed to load the model or initialize searcher: %s", e)
    raise RuntimeError("Failed to load the custom model or initialize searcher.") from e
|
|
|
|
|
def embed_text(text):
    """Embed text with the module-level tokenizer and model.

    The input is tokenized (with padding and truncation), run through the
    model, and the final hidden states are mean-pooled over dimension 1.

    Args:
        text: Input passed straight to the tokenizer. The callers in this
            file pass a single string.

    Returns:
        A NumPy array holding the pooled embeddings. Assumes
        ``last_hidden_state`` is laid out as (batch, tokens, hidden), which
        gives one row per input text -- TODO confirm for this model.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Inference only: disable autograd so no computation graph is built and
    # kept alive for every request.
    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state
    mask = inputs.get("attention_mask")
    if mask is None:
        embeddings = hidden.mean(dim=1)
    else:
        # Weight by the attention mask so padding tokens added for batched
        # input do not dilute the mean. For a single unpadded text the mask
        # is all ones and this equals the plain mean.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        token_counts = weights.sum(dim=1).clamp(min=1e-9)
        embeddings = (hidden * weights).sum(dim=1) / token_counts

    return embeddings.detach().numpy()
|
|
|
|
|
class SearchDocumentsRequest(BaseModel):
    """Request body for POST /api/search-documents."""

    # Text that the endpoint embeds and searches with.
    query: str
    # Forwarded to searcher.search_documents as its limit argument;
    # defaults to 3 when the client omits it.
    limit: int = 3
|
|
|
class GenerateRAGRequest(BaseModel):
    """Request body for POST /api/generate-rag-response."""

    # Text that is embedded for retrieval and also handed to
    # generate_rag_response together with the retrieved hits.
    search_query: str
|
|
|
|
|
@app.post("/api/search-documents")
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Embed the query and return the searcher's hits for the caller.

    Args:
        body: Parsed request body carrying ``query`` and ``limit``.
        credentials: ``(customer_id, user_id)`` pair produced by the
            ``token_required`` dependency.

    Returns:
        The hits returned by ``searcher.search_documents``.

    Raises:
        HTTPException: 401 when either id is missing from the credentials;
            500 when the searcher reports an error or anything else fails.
    """
    customer_id, user_id = credentials

    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to search documents")
    try:
        logging.info("Starting document search")

        query_embedding = embed_text(body.query)

        # NOTE(review): the generate-rag-response endpoint in this file
        # searches the "documents" collection instead -- confirm which
        # collection name is the intended one.
        collection_name = "my_embeddings"

        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        return hits
    except HTTPException:
        # Already a well-formed HTTP error (raised just above). Let it pass
        # through untouched: the generic handler below would re-wrap it,
        # mangling `detail` and logging it a second time as "unexpected".
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
@app.post("/api/generate-rag-response")
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Retrieve documents for the query and generate a RAG response.

    Args:
        body: Parsed request body carrying ``search_query``.
        credentials: ``(customer_id, user_id)`` pair produced by the
            ``token_required`` dependency.

    Returns:
        A dict of the form ``{"response": <generated response>}``.

    Raises:
        HTTPException: 401 when either id is missing from the credentials;
            500 when the search or the generation step reports an error, or
            anything else fails.
    """
    customer_id, user_id = credentials

    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to generate RAG response")
    try:
        logging.info("Starting document search")

        query_embedding = embed_text(body.search_query)

        # NOTE(review): the search-documents endpoint in this file queries
        # "my_embeddings" and passes an explicit limit; here the collection
        # is "documents" and the searcher's default limit applies -- confirm
        # that the difference is intentional.
        hits, error = searcher.search_documents("documents", query_embedding, user_id)

        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info("Generating RAG response")

        response, error = generate_rag_response(hits, body.search_query)

        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)

        return {"response": response}
    except HTTPException:
        # Already a well-formed HTTP error (raised just above). Let it pass
        # through untouched: the generic handler below would re-wrap it,
        # mangling `detail` and logging it a second time as "unexpected".
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
if __name__ == '__main__':
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000. uvicorn is imported here so it is only needed for this path.
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8000)
|
|