Milvus-Server / main.py
ruslanmv's picture
updates
f72d2d6
raw
history blame
4.97 kB
from io import BytesIO
from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
import os
import pypdf
from uuid import uuid4
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import torch
from milvus_singleton import MilvusClientSingleton
# Point the Hugging Face caches at a directory inside the application tree.
# NOTE(review): these are assigned after the library imports above; any
# library that reads them at import time will not see them -- confirm.
os.environ.update({
    'HF_HOME': '/app/cache',
    'HF_MODULES_CACHE': '/app/cache/hf_modules',
})

# Embedding model shared by the whole module. It runs on the GPU whenever
# torch reports CUDA as available and on the CPU otherwise.
embedding_model = SentenceTransformer(
    'Alibaba-NLP/gte-large-en-v1.5',
    trust_remote_code=True,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    cache_folder='/app/cache',
)

# Milvus connection details: the collection name and the connection URI.
collection_name = "rag"
# NOTE(review): "$MILVUS_DATA_DIR" in the fallback is a literal string here;
# nothing in this file expands it -- confirm the singleton handles it.
milvus_uri = os.environ.get("MILVUS_URI", "sqlite:///$MILVUS_DATA_DIR/milvus_demo.db")

# Shared Milvus client obtained through the project's singleton helper.
milvus_client = MilvusClientSingleton.get_instance(uri=milvus_uri)
def document_to_embeddings(content: str) -> list:
    """Embed ``content`` with the module-level ``embedding_model``.

    Args:
        content: Text to encode.

    Returns:
        The value produced by ``embedding_model.encode`` for ``content``.
    """
    embedding = embedding_model.encode(content, show_progress_bar=True)
    return embedding
app = FastAPI()

# Add CORS middleware
# Every origin, method and header is currently accepted, with credentials
# allowed.
_cors_options = {
    "allow_origins": ["*"],  # Replace with allowed origins for production
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
def split_documents(document_data: str, *, chunk_size: int = 512, chunk_overlap: int = 10) -> list:
    """Split text into overlapping chunks with a recursive character splitter.

    Args:
        document_data: The full text to split.
        chunk_size: Maximum size of each chunk, forwarded to
            ``RecursiveCharacterTextSplitter``. Defaults to 512.
        chunk_overlap: Overlap between consecutive chunks, forwarded to
            ``RecursiveCharacterTextSplitter``. Defaults to 10.

    Returns:
        The chunks returned by the splitter's ``split_text``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_text(document_data)
def create_a_collection(milvus_client, collection_name):
    """Create the collection and build an index on its ``vector`` field.

    The schema has three fields: a VARCHAR primary key ``id`` (max length
    40), a VARCHAR ``content`` (max length 4096) and a 1024-dimensional
    FLOAT_VECTOR ``vector``.

    Args:
        milvus_client: Client whose ``create_collection`` is called.
        collection_name: Name of the collection to create and index.
    """
    fields = [
        FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=40, is_primary=True),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
    ]
    milvus_client.create_collection(
        collection_name=collection_name,
        schema=CollectionSchema(fields=fields),
    )

    # The index is created through the Collection API, which uses its own
    # connection opened against the module-level ``milvus_uri``.
    connections.connect(uri=milvus_uri)
    target = Collection(name=collection_name)

    # FLAT index with the COSINE metric.
    # NOTE(review): "nlist" is normally an IVF-family parameter; presumably
    # it is ignored for FLAT -- confirm which index type was intended.
    target.create_index(
        field_name="vector",
        index_params={
            "index_type": "FLAT",
            "metric_type": "COSINE",
            "params": {"nlist": 128},
        },
    )
@app.get("/")
async def root():
    """Return a fixed greeting payload."""
    greeting = {"message": "Hello World"}
    return greeting
@app.post("/insert")
async def insert(file: UploadFile = File(...)):
    """Extract the text of an uploaded PDF, embed it in chunks and store it.

    The text of every page is concatenated, split with ``split_documents``,
    embedded chunk by chunk with ``document_to_embeddings`` and inserted
    into the ``collection_name`` collection, which is created first when it
    does not exist yet.

    Args:
        file: The uploaded PDF document.

    Returns:
        A ``JSONResponse``: 200 with ``{"result": "good"}`` on success, 400
        when the upload cannot be parsed or contains no extractable text,
        and 500 when embedding or the Milvus operations fail.
    """
    raw_bytes = await file.read()

    # Parse the upload before touching Milvus so that a bad file is reported
    # as a client error and does not create an empty collection.
    try:
        reader = pypdf.PdfReader(BytesIO(raw_bytes))
        # Join once instead of growing a string page by page; ``or ""``
        # guards against a page that yields no text object.
        extracted_text = "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": f"Could not read PDF: {e}"})

    chunks = split_documents(extracted_text)
    if not chunks:
        return JSONResponse(
            status_code=400,
            content={"error": "No text could be extracted from the uploaded file."},
        )

    try:
        if not milvus_client.has_collection(collection_name):
            create_a_collection(milvus_client, collection_name)
        data_objects = [
            {
                "id": str(uuid4()),
                "vector": document_to_embeddings(chunk),
                "content": chunk,
            }
            for chunk in chunks
        ]
        milvus_client.insert(collection_name=collection_name, data=data_objects)
    except Exception as e:
        # A response object is not an exception, so it must be returned,
        # not raised.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(status_code=200, content={"result": 'good'})
class RAGRequest(BaseModel):
    """Request body of the ``/rag`` endpoint."""

    # Text that is embedded and used as the similarity-search query.
    question: str
@app.post("/rag")
async def rag(request: RAGRequest):
    """Return the stored chunks most similar to the submitted question.

    The question is embedded with ``document_to_embeddings`` and searched
    against the ``collection_name`` collection.

    Args:
        request: Request body carrying the ``question`` string.

    Returns:
        A ``JSONResponse``: 200 with ``{"result": [...]}`` holding the
        ``content`` of each hit, 400 when the question is empty or only
        whitespace, and 500 when embedding or the search fails.
    """
    question = request.question
    if not question or not question.strip():
        return JSONResponse(status_code=400, content={"message": "Please provide a question!"})

    try:
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[document_to_embeddings(question)],
            limit=5,  # Return the top 5 results
            output_fields=["content"],  # Return the text field
        )
        # One query vector was sent, so only the first result list is used.
        retrieved_contents = [hit["entity"]["content"] for hit in search_res[0]]
    except Exception as e:
        # A failure here is server-side, so report 500 as /insert does.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(status_code=200, content={"result": retrieved_contents})