ruslanmv commited on
Commit
623fdec
·
1 Parent(s): f72d2d6
Files changed (2) hide show
  1. main.py +7 -8
  2. milvus_singleton.py +8 -12
main.py CHANGED
@@ -1,10 +1,10 @@
1
  from io import BytesIO
2
- from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
3
  from fastapi.encoders import jsonable_encoder
4
  from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
- from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
8
  import os
9
  import pypdf
10
  from uuid import uuid4
@@ -21,16 +21,16 @@ os.environ['HF_MODULES_CACHE'] = '/app/cache/hf_modules'
21
  embedding_model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5',
22
  trust_remote_code=True,
23
  device='cuda' if torch.cuda.is_available() else 'cpu',
24
- cache_folder='/app/cache')
25
 
26
  # Milvus connection details
27
- collection_name="rag"
28
- milvus_uri = os.getenv("MILVUS_URI", "sqlite:///$MILVUS_DATA_DIR/milvus_demo.db")
29
 
30
  # Initialize Milvus client using singleton
31
  milvus_client = MilvusClientSingleton.get_instance(uri=milvus_uri)
32
 
33
- def document_to_embeddings(content:str) -> list:
34
  return embedding_model.encode(content, show_progress_bar=True)
35
 
36
  app = FastAPI()
@@ -60,7 +60,6 @@ def create_a_collection(milvus_client, collection_name):
60
  collection_name=collection_name,
61
  schema=schema
62
  )
63
- connections.connect(uri=milvus_uri)
64
  collection = Collection(name=collection_name)
65
  # Create an index for the collection
66
  # IVF_FLAT index is used here, with metric_type COSINE
@@ -123,7 +122,7 @@ async def rag(request: RAGRequest):
123
  data=[
124
  document_to_embeddings(question)
125
  ],
126
- limit=5, # Return top 3 results
127
  # search_params={"metric_type": "COSINE"}, # Inner product distance
128
  output_fields=["content"], # Return the text field
129
  )
 
1
  from io import BytesIO
2
+ from fastapi import FastAPI, File, UploadFile
3
  from fastapi.encoders import jsonable_encoder
4
  from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
+ from pymilvus import utility, Collection, CollectionSchema, FieldSchema, DataType
8
  import os
9
  import pypdf
10
  from uuid import uuid4
 
21
  embedding_model = SentenceTransformer('Alibaba-NLP/gte-large-en-v1.5',
22
  trust_remote_code=True,
23
  device='cuda' if torch.cuda.is_available() else 'cpu',
24
+ cache_folder='/app/cache')
25
 
26
  # Milvus connection details
27
+ collection_name = "rag"
28
+ milvus_uri = os.getenv("MILVUS_URI", "http://localhost:19530") # Correct URI for Milvus
29
 
30
  # Initialize Milvus client using singleton
31
  milvus_client = MilvusClientSingleton.get_instance(uri=milvus_uri)
32
 
33
+ def document_to_embeddings(content: str) -> list:
34
  return embedding_model.encode(content, show_progress_bar=True)
35
 
36
  app = FastAPI()
 
60
  collection_name=collection_name,
61
  schema=schema
62
  )
 
63
  collection = Collection(name=collection_name)
64
  # Create an index for the collection
65
  # IVF_FLAT index is used here, with metric_type COSINE
 
122
  data=[
123
  document_to_embeddings(question)
124
  ],
125
+ limit=5, # Return top 5 results
126
  # search_params={"metric_type": "COSINE"}, # Inner product distance
127
  output_fields=["content"], # Return the text field
128
  )
milvus_singleton.py CHANGED
@@ -1,25 +1,21 @@
1
- from pymilvus import Milvus, connections
2
  from pymilvus.exceptions import ConnectionConfigException
3
 
4
  class MilvusClientSingleton:
5
- _instance = None
6
-
7
- @staticmethod
8
- def get_instance(uri):
9
- if MilvusClientSingleton._instance is None:
10
- MilvusClientSingleton(uri)
11
- return MilvusClientSingleton._instance
12
 
13
  def __init__(self, uri):
14
  if MilvusClientSingleton._instance is not None:
15
  raise Exception("This class is a singleton!")
16
  try:
17
- # Use the regular Milvus client (not MilvusClient)
18
- self._instance = Milvus(uri=uri)
 
19
  print(f"Successfully connected to Milvus at {uri}")
20
  except ConnectionConfigException as e:
21
  print(f"Error connecting to Milvus: {e}")
22
- raise # Re-raise the exception to stop initialization
23
 
24
  def __getattr__(self, name):
25
- return getattr(self._instance, name)
 
 
1
+ from pymilvus import connections
2
  from pymilvus.exceptions import ConnectionConfigException
3
 
4
  class MilvusClientSingleton:
5
+ # ... (rest of the class code)
 
 
 
 
 
 
6
 
7
  def __init__(self, uri):
8
  if MilvusClientSingleton._instance is not None:
9
  raise Exception("This class is a singleton!")
10
  try:
11
+ # Use connections.connect()
12
+ connections.connect(uri=uri)
13
+ self._instance = connections # Store the connections object
14
  print(f"Successfully connected to Milvus at {uri}")
15
  except ConnectionConfigException as e:
16
  print(f"Error connecting to Milvus: {e}")
17
+ raise
18
 
19
  def __getattr__(self, name):
20
+ # Delegate attribute access to the default connection
21
+ return getattr(connections, name)