ruslanmv commited on
Commit
11eaf27
·
1 Parent(s): ffe9995

First commit

Browse files
Files changed (5) hide show
  1. Dockerfile +30 -0
  2. Dockerfile.milvus +32 -0
  3. app/main.py +135 -0
  4. app/milvus_singleton.py +25 -0
  5. app/requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10.8
2
+
3
+ WORKDIR /app
4
+
5
+ COPY ./app/requirements.txt /app/requirements.txt
6
+
7
+ # Create cache and milvus_data directories and set permissions
8
+ RUN mkdir -p /app/cache /app/milvus_data && chmod -R 777 /app/cache /app/milvus_data
9
+
10
+ # Install dependencies
11
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
12
+
13
+ # Create a non-root user
14
+ RUN useradd -m -u 1000 user
15
+ USER user
16
+
17
+ # Set environment variables for Hugging Face cache and Milvus data
18
+ ENV HF_HOME=/app/cache \
19
+ HF_MODULES_CACHE=/app/cache/hf_modules \
20
+ MILVUS_DATA_DIR=/app/milvus_data \
21
+ HF_WORKER_COUNT=1
22
+
23
+ # Copy the application code
24
+ COPY ./app /app
25
+
26
+ # Expose the port Uvicorn will run on
27
+ EXPOSE 7860
28
+
29
+ # Start Uvicorn
30
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
Dockerfile.milvus ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11
2
+
3
+ # Install Milvus dependencies
4
+ USER root
5
+ RUN apt-get update && apt-get install -y wget ffmpeg libsm6 libxext6 libaio1
6
+
7
+ # Download and install Milvus
8
+ RUN wget https://github.com/milvus-io/milvus/releases/download/v2.3.7/milvus_2.3.7-1_amd64.deb && \
9
+ dpkg -i milvus_2.3.7-1_amd64.deb && \
10
+ apt-get -f install && \
11
+ apt-get clean && \
12
+ rm milvus_2.3.7-1_amd64.deb
13
+
14
+ # Create a directory for Milvus data
15
+ RUN mkdir -p /milvus/data
16
+
17
+ # Set up Milvus user
18
+ RUN useradd -m -u 1000 milvus
19
+ USER milvus
20
+
21
+ # Set Milvus environment variables
22
+ ENV MILVUS_HOME=/home/milvus
23
+ ENV PATH=$MILVUS_HOME/bin:$PATH
24
+
25
+ # Set working directory
26
+ WORKDIR $MILVUS_HOME
27
+
28
+ # Expose Milvus ports
29
+ EXPOSE 19530
30
+
31
+ # Start Milvus server
32
+ CMD ["milvus", "run", "standalone", "-d", "/milvus/data"]
app/main.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
3
+ from fastapi.encoders import jsonable_encoder
4
+ from fastapi.responses import JSONResponse
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from pydantic import BaseModel
7
+ from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
8
+ import os
9
+ import pypdf
10
+ from uuid import uuid4
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from sentence_transformers import SentenceTransformer
13
+ import torch
14
+ from app.milvus_singleton import MilvusClientSingleton
15
+
16
+ # Set environment variables for Hugging Face cache
17
+ os.environ['HF_HOME'] = '/app/cache'
18
+ os.environ['HF_MODULES_CACHE'] = '/app/cache/hf_modules'
19
+
20
+ # Embedding model
21
+ embedding_model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5',
22
+ trust_remote_code=True,
23
+ device='cuda' if torch.cuda.is_available() else 'cpu',
24
+ cache_folder='/app/cache')
25
+
26
+ # Milvus connection details
27
+ collection_name="rag"
28
+ milvus_uri = os.getenv("MILVUS_URI", "sqlite:///$MILVUS_DATA_DIR/milvus_demo.db")
29
+
30
+ # Initialize Milvus client using singleton
31
+ milvus_client = MilvusClientSingleton.get_instance(uri=milvus_uri)
32
+
33
+ def document_to_embeddings(content:str) -> list:
34
+ return embedding_model.encode(content, show_progress_bar=True)
35
+
36
+ app = FastAPI()
37
+
38
+ # Add CORS middleware
39
+ app.add_middleware(
40
+ CORSMiddleware,
41
+ allow_origins=["*"], # Replace with allowed origins for production
42
+ allow_credentials=True,
43
+ allow_methods=["*"],
44
+ allow_headers=["*"],
45
+ )
46
+
47
+ def split_documents(document_data):
48
+ splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=10)
49
+ return splitter.split_text(document_data)
50
+
51
+ def create_a_collection(milvus_client, collection_name):
52
+ # Define the fields for the collection
53
+ id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=40, is_primary=True)
54
+ content_field = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=4096)
55
+ vector_field = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024)
56
+ # Define the schema for the collection
57
+ schema = CollectionSchema(fields=[id_field, content_field, vector_field])
58
+ # Create the collection
59
+ milvus_client.create_collection(
60
+ collection_name=collection_name,
61
+ schema=schema
62
+ )
63
+ connections.connect(uri=milvus_uri)
64
+ collection = Collection(name=collection_name)
65
+ # Create an index for the collection
66
+ # IVF_FLAT index is used here, with metric_type COSINE
67
+ index_params = {
68
+ "index_type": "FLAT",
69
+ "metric_type": "COSINE",
70
+ "params": {
71
+ "nlist": 128
72
+ }
73
+ }
74
+ # Create the index on the vector field
75
+ collection.create_index(
76
+ field_name="vector",
77
+ index_params=index_params
78
+ )
79
+
80
+ @app.get("/")
81
+ async def root():
82
+ return {"message": "Hello World"}
83
+
84
+ @app.post("/insert")
85
+ async def insert(file: UploadFile = File(...)):
86
+ contents = await file.read()
87
+ if not milvus_client.has_collection(collection_name):
88
+ create_a_collection(milvus_client, collection_name)
89
+ contents = pypdf.PdfReader(BytesIO(contents))
90
+ extracted_text = ""
91
+ for page_num in range(len(contents.pages)):
92
+ page = contents.pages[page_num]
93
+ extracted_text += page.extract_text()
94
+ splitted_document_data = split_documents(extracted_text)
95
+ print(splitted_document_data)
96
+ data_objects = []
97
+ for doc in splitted_document_data:
98
+ data = {
99
+ "id": str(uuid4()),
100
+ "vector": document_to_embeddings(doc),
101
+ "content": doc,
102
+ }
103
+ data_objects.append(data)
104
+ print(data_objects)
105
+ try:
106
+ milvus_client.insert(collection_name=collection_name, data=data_objects)
107
+ except Exception as e:
108
+ raise JSONResponse(status_code=500, content={"error": str(e)})
109
+ else:
110
+ return JSONResponse(status_code=200, content={"result": 'good'})
111
+
112
+ class RAGRequest(BaseModel):
113
+ question: str
114
+
115
+ @app.post("/rag")
116
+ async def rag(request: RAGRequest):
117
+ question = request.question
118
+ if not question:
119
+ return JSONResponse(status_code=400, content={"message": "Please a question!"})
120
+ try:
121
+ search_res = milvus_client.search(
122
+ collection_name=collection_name,
123
+ data=[
124
+ document_to_embeddings(question)
125
+ ],
126
+ limit=5, # Return top 3 results
127
+ # search_params={"metric_type": "COSINE"}, # Inner product distance
128
+ output_fields=["content"], # Return the text field
129
+ )
130
+ retrieved_lines_with_distances = [
131
+ (res["entity"]["content"]) for res in search_res[0]
132
+ ]
133
+ return JSONResponse(status_code=200, content={"result": retrieved_lines_with_distances})
134
+ except Exception as e:
135
+ return JSONResponse(status_code=400, content={"error": str(e)})
app/milvus_singleton.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymilvus import Milvus, connections
2
+ from pymilvus.exceptions import ConnectionConfigException
3
+
4
+ class MilvusClientSingleton:
5
+ _instance = None
6
+
7
+ @staticmethod
8
+ def get_instance(uri):
9
+ if MilvusClientSingleton._instance is None:
10
+ MilvusClientSingleton(uri)
11
+ return MilvusClientSingleton._instance
12
+
13
+ def __init__(self, uri):
14
+ if MilvusClientSingleton._instance is not None:
15
+ raise Exception("This class is a singleton!")
16
+ try:
17
+ # Use the regular Milvus client (not MilvusClient)
18
+ self._instance = Milvus(uri=uri)
19
+ print(f"Successfully connected to Milvus at {uri}")
20
+ except ConnectionConfigException as e:
21
+ print(f"Error connecting to Milvus: {e}")
22
+ raise # Re-raise the exception to stop initialization
23
+
24
+ def __getattr__(self, name):
25
+ return getattr(self._instance, name)
app/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pypdf
4
+ python-multipart
5
+ langchain
6
+ sentence-transformers
7
+ torch
8
+ pymilvu