# NOTE(review): "Last commit not found" -- stray non-code line (likely page-scrape residue); confirm and remove.
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
from typing import List | |
import json | |
import os | |
import logging | |
from txtai.embeddings import Embeddings | |
# Configure logging at import time and expose a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Enable CORS for every origin, method and header.
# NOTE(review): wildcard origins are combined with allow_credentials=True
# here; FastAPI's CORS documentation advises against that combination --
# confirm whether credentials are needed or list the origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Single shared txtai Embeddings instance configured with the model path
# below. The save/load helpers and both request handlers in this file all
# operate on this one object.
embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})
class DocumentRequest(BaseModel):
    """Request body for building an index from a list of documents."""

    # Name of the index; used as the folder name under ``indexes/``.
    index_id: str
    # Texts to index; each text's position in the list becomes its id.
    documents: List[str]
class QueryRequest(BaseModel):
    """Request body for searching a previously saved index."""

    # Name of the index to load before searching.
    index_id: str
    # Search text passed to ``embeddings.search``.
    query: str
    # Passed as the second argument to ``embeddings.search``.
    num_results: int
def save_embeddings(index_id, document_list):
    """Persist the shared embeddings index together with its documents.

    Saves the module-level ``embeddings`` object to
    ``indexes/<index_id>/embeddings`` and dumps ``document_list`` as JSON to
    ``indexes/<index_id>/document_list.json``, creating the folder if needed.

    Args:
        index_id: Name of the index; becomes a folder under ``indexes/``.
        document_list: JSON-serialisable list of the original documents.

    Raises:
        HTTPException: 400 if ``index_id`` resolves outside ``indexes/``;
            500 if anything fails while saving.
    """
    folder_path = f"indexes/{index_id}"

    # index_id comes from the request body, so reject values such as "../x"
    # that would make the target folder escape the indexes root. This check
    # sits outside the try block so the 400 is not rewritten as a 500.
    root = os.path.abspath("indexes")
    target = os.path.abspath(folder_path)
    if target != root and not target.startswith(root + os.sep):
        logger.error("Rejected index_id outside indexes root: %r", index_id)
        raise HTTPException(status_code=400, detail="Invalid index_id")

    try:
        os.makedirs(folder_path, exist_ok=True)

        # Save embeddings
        embeddings.save(f"{folder_path}/embeddings")

        # Save document_list
        with open(
            f"{folder_path}/document_list.json", "w", encoding="utf-8"
        ) as f:
            json.dump(document_list, f)
    except Exception as e:
        logger.exception(
            "Error saving embeddings for index_id %s: %s", index_id, e
        )
        raise HTTPException(
            status_code=500, detail=f"Error saving embeddings: {str(e)}"
        ) from e

    logger.info(
        "Embeddings and document list saved for index_id: %s", index_id
    )
def load_embeddings(index_id):
    """Load a saved index into the shared embeddings object.

    Loads ``indexes/<index_id>/embeddings`` into the module-level
    ``embeddings`` object and reads the matching ``document_list.json``.

    Args:
        index_id: Name of the index; the folder name under ``indexes/``.

    Returns:
        The document list decoded from ``document_list.json``.

    Raises:
        HTTPException: 400 if ``index_id`` resolves outside ``indexes/``;
            404 if the index folder does not exist; 500 if loading fails.
    """
    folder_path = f"indexes/{index_id}"

    # index_id comes from the request body, so reject values such as "../x"
    # that would make the target folder escape the indexes root.
    root = os.path.abspath("indexes")
    target = os.path.abspath(folder_path)
    if target != root and not target.startswith(root + os.sep):
        logger.error("Rejected index_id outside indexes root: %r", index_id)
        raise HTTPException(status_code=400, detail="Invalid index_id")

    # The existence check sits outside the try block: raised inside it, the
    # 404 would be caught by ``except Exception`` and reported as a 500.
    if not os.path.exists(folder_path):
        logger.error("Index not found for index_id: %s", index_id)
        raise HTTPException(status_code=404, detail="Index not found")

    try:
        # Load embeddings
        embeddings.load(f"{folder_path}/embeddings")

        # Load document_list
        with open(
            f"{folder_path}/document_list.json", "r", encoding="utf-8"
        ) as f:
            document_list = json.load(f)
    except Exception as e:
        logger.exception(
            "Error loading embeddings for index_id %s: %s", index_id, e
        )
        raise HTTPException(
            status_code=500, detail=f"Error loading embeddings: {str(e)}"
        ) from e

    logger.info(
        "Embeddings and document list loaded for index_id: %s", index_id
    )
    return document_list
async def create_index(request: DocumentRequest):
    """Index the request's documents and persist the result.

    Builds ``(position, text, None)`` tuples from ``request.documents``,
    indexes them with the shared ``embeddings`` object, then saves the index
    and the original documents under ``request.index_id``.

    NOTE(review): no route decorator is visible on this function in the
    source as seen here -- confirm it is registered with ``app``.

    Args:
        request: Index name and the documents to index.

    Returns:
        ``{"message": "Index created successfully"}`` on success.

    Raises:
        HTTPException: whatever ``save_embeddings`` raised, passed through
            unchanged; otherwise 500 for any other failure.
    """
    try:
        document_list = [
            (i, text, None) for i, text in enumerate(request.documents)
        ]
        embeddings.index(document_list)
        # Save the original documents
        save_embeddings(request.index_id, request.documents)
    except HTTPException:
        # Already a well-formed HTTP error; re-wrapping it would overwrite
        # its status code and nest the detail message.
        raise
    except Exception as e:
        logger.exception("Error creating index: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error creating index: {str(e)}"
        ) from e

    logger.info(
        "Index created successfully for index_id: %s", request.index_id
    )
    return {"message": "Index created successfully"}
async def query_index(request: QueryRequest):
    """Search a saved index and return the matching original texts.

    Loads the index named by ``request.index_id`` into the shared
    ``embeddings`` object, runs the search, and maps the first element of
    each result back to the stored document list.

    NOTE(review): no route decorator is visible on this function in the
    source as seen here -- confirm it is registered with ``app``.

    Args:
        request: Index name, query text and ``num_results``.

    Returns:
        ``{"queried_texts": [...]}`` with one text per search result.

    Raises:
        HTTPException: whatever ``load_embeddings`` raised (e.g. 404 for a
            missing index), passed through unchanged; otherwise 500.
    """
    try:
        document_list = load_embeddings(request.index_id)
        results = embeddings.search(request.query, request.num_results)
        # The first element of each result is used as the position of the
        # document in the saved list.
        queried_texts = [document_list[hit[0]] for hit in results]
    except HTTPException:
        # Keep the status chosen by load_embeddings (such as 404) instead of
        # collapsing every failure into a 500.
        raise
    except Exception as e:
        logger.exception("Error querying index: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error querying index: {str(e)}"
        ) from e

    logger.info(
        "Query executed successfully for index_id: %s", request.index_id
    )
    return {"queried_texts": queried_texts}
if __name__ == "__main__":
    # uvicorn is imported here, so it is only needed when this file is run
    # directly as a script.
    import uvicorn

    # Serve the app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)