rag-chat / main.py
pvanand's picture
Create main.py
63df3f2 verified
raw
history blame
3.6 kB
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import json
import os
import logging
from txtai.embeddings import Embeddings
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# CORS is fully open: any origin, any method, any header, credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests to this API — confirm this is intended before
# exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# One process-wide Embeddings object shared by every endpoint. create_index
# re-indexes it and query_index re-loads it from disk on each request.
embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})
class DocumentRequest(BaseModel):
    """Request body for ``POST /create_index/``."""

    # Name of the index; used as the folder name under indexes/ when saving.
    index_id: str
    # Raw texts to embed; each text's list position becomes its id in the index.
    documents: List[str]
class QueryRequest(BaseModel):
    """Request body for ``POST /query_index/``."""

    # Name of a previously created index (loaded from indexes/<index_id>).
    index_id: str
    # Free-text query passed to embeddings.search.
    query: str
    # Passed as the result limit to embeddings.search.
    # NOTE(review): not validated here — zero/negative values go through as-is.
    num_results: int
def save_embeddings(index_id, document_list):
    """Persist the current in-memory index and its documents to ``indexes/<index_id>``.

    The index folder is created if missing. Two artefacts are written into it:
    ``embeddings`` (via ``embeddings.save``) and ``document_list.json``. The JSON
    file is kept so that search hits, which come back as ids, can be mapped back
    to the original texts.

    Args:
        index_id: Index name taken from the request body (untrusted input).
        document_list: JSON-serialisable list of the original documents.

    Raises:
        HTTPException: 400 if ``index_id`` would resolve outside the
            ``indexes`` directory; 500 if saving fails for any other reason.
    """
    # index_id comes straight from the client; refuse values such as "../x" or
    # absolute paths that would make us write outside indexes/.
    base_path = os.path.realpath("indexes")
    try:
        folder_path = os.path.realpath(os.path.join(base_path, index_id))
        is_inside = os.path.commonpath([base_path, folder_path]) == base_path
    except ValueError:  # embedded NUL byte, or a different drive on Windows
        is_inside = False
    if not is_inside or folder_path == base_path:
        logger.error(f"Invalid index_id rejected: {index_id!r}")
        raise HTTPException(status_code=400, detail="Invalid index_id")

    try:
        os.makedirs(folder_path, exist_ok=True)
        embeddings.save(os.path.join(folder_path, "embeddings"))
        with open(
            os.path.join(folder_path, "document_list.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(document_list, f)
        logger.info(f"Embeddings and document list saved for index_id: {index_id}")
    except Exception as e:
        logger.exception(f"Error saving embeddings for index_id {index_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error saving embeddings: {str(e)}"
        ) from e
def load_embeddings(index_id):
    """Load ``indexes/<index_id>`` into the shared ``embeddings`` object.

    Args:
        index_id: Index name taken from the request body (untrusted input).

    Returns:
        The document list saved alongside the index (``document_list.json``).

    Raises:
        HTTPException: 400 if ``index_id`` would resolve outside the
            ``indexes`` directory; 404 if the index folder does not exist;
            500 if loading the index or the document list fails.
    """
    # index_id comes straight from the client; refuse values such as "../x" or
    # absolute paths that would make us read outside indexes/.
    base_path = os.path.realpath("indexes")
    try:
        folder_path = os.path.realpath(os.path.join(base_path, index_id))
        is_inside = os.path.commonpath([base_path, folder_path]) == base_path
    except ValueError:  # embedded NUL byte, or a different drive on Windows
        is_inside = False
    if not is_inside or folder_path == base_path:
        logger.error(f"Invalid index_id rejected: {index_id!r}")
        raise HTTPException(status_code=400, detail="Invalid index_id")

    # Checked outside the try below: previously the 404 was caught by the
    # generic handler and turned into a 500.
    if not os.path.exists(folder_path):
        logger.error(f"Index not found for index_id: {index_id}")
        raise HTTPException(status_code=404, detail="Index not found")

    try:
        embeddings.load(os.path.join(folder_path, "embeddings"))
        with open(
            os.path.join(folder_path, "document_list.json"), "r", encoding="utf-8"
        ) as f:
            document_list = json.load(f)
    except Exception as e:
        logger.exception(f"Error loading embeddings for index_id {index_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error loading embeddings: {str(e)}"
        ) from e

    logger.info(f"Embeddings and document list loaded for index_id: {index_id}")
    return document_list
@app.post("/create_index/")
async def create_index(request: DocumentRequest):
    """Index ``request.documents`` and save the result under ``request.index_id``.

    Returns:
        ``{"message": "Index created successfully"}`` on success.

    Raises:
        HTTPException: re-raised unchanged from ``save_embeddings`` (which picks
            its own status code), or 500 for any other failure.
    """
    # NOTE(review): declared async but everything below is synchronous, so the
    # event loop is blocked while indexing. Because nothing is awaited, requests
    # also cannot interleave on the shared `embeddings` object. Keep that in mind
    # before switching to a plain `def`, which FastAPI would run in a threadpool.
    try:
        # (id, text, tags) rows; the list position is the id, which query_index
        # later uses to look the text up in the saved document list.
        rows = [(i, text, None) for i, text in enumerate(request.documents)]
        embeddings.index(rows)
        save_embeddings(request.index_id, request.documents)  # Save the original documents
    except HTTPException:
        # Already logged and mapped to a status code; don't re-wrap it as a new 500.
        raise
    except Exception as e:
        logger.exception(f"Error creating index: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error creating index: {str(e)}"
        ) from e

    logger.info(f"Index created successfully for index_id: {request.index_id}")
    return {"message": "Index created successfully"}
@app.post("/query_index/")
async def query_index(request: QueryRequest):
    """Search index ``request.index_id`` and return the matching original texts.

    Returns:
        ``{"queried_texts": [...]}``, holding the stored documents for each hit
        in the order returned by ``embeddings.search``.

    Raises:
        HTTPException: re-raised unchanged from ``load_embeddings`` (e.g. 404
            for an unknown index), or 500 for any other failure.
    """
    try:
        document_list = load_embeddings(request.index_id)
        results = embeddings.search(request.query, request.num_results)
        # hit[0] is the id assigned in create_index, i.e. the document's position
        # in the saved list.
        queried_texts = [document_list[hit[0]] for hit in results]
    except HTTPException:
        # Preserve the status code chosen by load_embeddings (previously a
        # 404 was turned into a 500 here).
        raise
    except Exception as e:
        logger.exception(f"Error querying index: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error querying index: {str(e)}"
        ) from e

    logger.info(f"Query executed successfully for index_id: {request.index_id}")
    return {"queried_texts": queried_texts}
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this file is run directly.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")