"""FastAPI service that ingests PDFs into Milvus and answers retrieval queries.

``POST /insert`` extracts the text of an uploaded PDF, splits it into chunks,
embeds every chunk and stores (id, content, vector) rows in the ``rag`` Milvus
collection. ``POST /rag`` embeds a question and returns the text of the 5
closest stored chunks.
"""

import logging
import os
from io import BytesIO
from uuid import uuid4

# Hugging Face libraries resolve their cache locations into module-level
# constants when they are first imported, so these variables must be set
# BEFORE sentence_transformers (and, through it, transformers) is imported;
# assigning them afterwards has no effect on those libraries.
os.environ['HF_HOME'] = '/app/cache'
os.environ['HF_MODULES_CACHE'] = '/app/cache/hf_modules'

import pypdf  # noqa: E402
import torch  # noqa: E402
from fastapi import FastAPI, File, UploadFile  # noqa: E402
from fastapi.encoders import jsonable_encoder  # noqa: E402,F401
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import JSONResponse  # noqa: E402
from langchain.text_splitter import RecursiveCharacterTextSplitter  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from pymilvus import utility, Collection, CollectionSchema, FieldSchema, DataType  # noqa: E402,F401
from pypdf.errors import PdfReadError  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402

from milvus_singleton import MilvusClientSingleton  # noqa: E402

logger = logging.getLogger(__name__)

# Embedding model. trust_remote_code=True allows the model repository's own
# modeling code (downloaded from the Hub) to be executed.
embedding_model = SentenceTransformer(
    'Alibaba-NLP/gte-large-en-v1.5',
    trust_remote_code=True,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    cache_folder='/app/cache',
)

# Size of the "vector" field; must equal the embedding model's output size.
EMBEDDING_DIM = 1024

# Milvus connection details
collection_name = "rag"
milvus_uri = os.getenv("MILVUS_URI", "http://localhost:19530")

# Initialize Milvus client using singleton
milvus_client = MilvusClientSingleton.get_instance(uri=milvus_uri)


def document_to_embeddings(content: str) -> list:
    """Embed ``content`` and return the vector as a plain list of floats.

    ``SentenceTransformer.encode`` returns a NumPy array by default;
    ``tolist()`` turns it into the ``list`` the annotation promises.
    """
    return embedding_model.encode(content, show_progress_bar=True).tolist()


app = FastAPI()

# NOTE(review): "*" origins combined with allow_credentials=True lets any site
# issue credentialed cross-origin requests. Replace with the explicit list of
# allowed origins before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def split_documents(document_data):
    """Split ``document_data`` into ~512-character chunks overlapping by 10 characters."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=10)
    return splitter.split_text(document_data)


def create_a_collection(milvus_client, collection_name):
    """Create ``collection_name`` with an id/content/vector schema, index it and load it.

    Args:
        milvus_client: client exposing ``create_collection(collection_name=, schema=)``.
        collection_name: name of the collection to create.
    """
    id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=40, is_primary=True)
    content_field = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096)
    vector_field = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)
    schema = CollectionSchema(fields=[id_field, content_field, vector_field])

    milvus_client.create_collection(collection_name=collection_name, schema=schema)

    collection = Collection(name=collection_name)
    # Brute-force FLAT index with cosine similarity. FLAT has no build
    # parameters ("nlist" only applies to IVF_* index types).
    index_params = {
        "index_type": "FLAT",
        "metric_type": "COSINE",
        "params": {},
    }
    collection.create_index(field_name="vector", index_params=index_params)
    # A collection created from a bare schema is not loaded into memory, and
    # Milvus rejects searches on unloaded collections; load it so /rag works.
    collection.load()


@app.get("/")
async def root():
    """Liveness endpoint."""
    return {"message": "Hello World"}


@app.post("/insert")
async def insert(file: UploadFile = File(...)):
    """Extract text from an uploaded PDF, chunk and embed it, and store it in Milvus.

    Responses:
        200 ``{"result": "good"}`` when all chunks were inserted.
        400 ``{"error": ...}`` when the upload is not a readable PDF or has no
            extractable text.
        500 ``{"error": ...}`` when the Milvus insert fails.
    """
    # NOTE(review): PDF parsing and model encoding are blocking calls executed
    # on the event loop; consider offloading them (e.g. run_in_threadpool).
    raw_bytes = await file.read()

    try:
        reader = pypdf.PdfReader(BytesIO(raw_bytes))
        # Join pages with a newline so the last word of one page does not fuse
        # with the first word of the next; "or ''" guards pages that yield no text.
        extracted_text = "\n".join(page.extract_text() or "" for page in reader.pages)
    except PdfReadError as e:
        return JSONResponse(status_code=400, content={"error": f"Invalid PDF: {e}"})

    chunks = split_documents(extracted_text)
    if not chunks:
        return JSONResponse(status_code=400, content={"error": "No extractable text found in the PDF."})
    logger.debug("Split upload %r into %d chunks", file.filename, len(chunks))

    if not milvus_client.has_collection(collection_name):
        create_a_collection(milvus_client, collection_name)

    # One batched encode call instead of one model call per chunk.
    vectors = embedding_model.encode(chunks, show_progress_bar=True)
    data_objects = [
        {"id": str(uuid4()), "vector": vector.tolist(), "content": chunk}
        for chunk, vector in zip(chunks, vectors)
    ]

    try:
        milvus_client.insert(collection_name=collection_name, data=data_objects)
    except Exception as e:
        # JSONResponse is a response object, not an exception: return it.
        logger.exception("Milvus insert into %r failed", collection_name)
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(status_code=200, content={"result": 'good'})


class RAGRequest(BaseModel):
    """Request body for ``POST /rag``."""

    question: str


@app.post("/rag")
async def rag(request: RAGRequest):
    """Return the text of the stored chunks most similar to ``request.question``.

    Responses:
        200 ``{"result": [<chunk text>, ...]}`` with up to 5 items, in the
            order returned by the Milvus search.
        400 ``{"message": ...}`` for an empty or blank question.
        400 ``{"error": ...}`` when embedding or searching fails.
    """
    question = request.question
    if not question.strip():
        return JSONResponse(status_code=400, content={"message": "Please provide a question!"})

    try:
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[document_to_embeddings(question)],
            limit=5,  # Return top 5 results
            output_fields=["content"],  # Return the text field
        )
        # search_res[0] holds the hits for the single query vector sent above.
        retrieved_contents = [hit["entity"]["content"] for hit in search_res[0]]
    except Exception as e:
        # NOTE(review): server-side failures arguably deserve a 500; 400 is kept
        # for compatibility with existing clients.
        logger.exception("RAG search failed")
        return JSONResponse(status_code=400, content={"error": str(e)})

    return JSONResponse(status_code=200, content={"result": retrieved_contents})