Spaces:
Sleeping
Sleeping
# milvus.py | |
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility | |
import pandas as pd | |
import os | |
import sys | |
from sentence_transformers import SentenceTransformer | |
import time | |
# Default Milvus connection settings. The ``__main__`` block lets each one be
# overridden by an environment variable of the same purpose.
DEFAULT_MILVUS_HOST: str = "localhost"
DEFAULT_MILVUS_PORT: str = "19530"
DEFAULT_COLLECTION_NAME: str = "document_collection"

# Width of the ``content_vector`` field. It must equal the output size of the
# embedding model loaded below (all-MiniLM-L6-v2 produces 384-dim vectors).
DEFAULT_DIMENSION: int = 384

# Default connection retry policy.
DEFAULT_MAX_RETRIES: int = 3
DEFAULT_RETRY_DELAY: int = 5  # seconds to wait between attempts

# Sentence-embedding model, instantiated once when this module is imported.
model = SentenceTransformer("all-MiniLM-L6-v2")
def create_milvus_collection(host, port, collection_name, dimension, *, index_params=None):
    """Create the document collection and its vector index unless it already exists.

    The schema is:
      * ``id`` - INT64 primary key generated by Milvus (``auto_id=True``), so
        inserts only supply the remaining fields;
      * ``path`` - VARCHAR with ``max_length=500``;
      * ``content_vector`` - FLOAT_VECTOR with ``dimension`` components.

    An existing collection is reused as-is. Its schema and index are not
    checked against ``dimension`` or ``index_params``.

    Args:
        host: Unused. The function relies on the default pymilvus connection
            already being open. Kept for interface compatibility.
        port: Unused; see ``host``.
        collection_name: Name of the collection to create or reuse.
        dimension: Vector width of ``content_vector``. It must match the
            embedding model's output size.
        index_params: Optional index specification passed to
            ``Collection.create_index``. Defaults to an L2 / IVF_FLAT index
            with ``nlist=1024``.

    Returns:
        The ``Collection`` handle, whether newly created or pre-existing.
    """
    if utility.has_collection(collection_name):
        print(f"Collection {collection_name} already exists.")
        return Collection(collection_name)

    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="path", dtype=DataType.VARCHAR, max_length=500),
        FieldSchema(name="content_vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
    ]
    schema = CollectionSchema(fields, "Document Vector Store")
    collection = Collection(collection_name, schema, consistency_level="Strong")

    # Build the default inside the function to avoid a shared mutable default.
    if index_params is None:
        index_params = {
            "metric_type": "L2",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024},
        }
    collection.create_index(field_name="content_vector", index_params=index_params)
    print(f"Collection {collection_name} created and index built.")
    return collection
def load_data_to_milvus(host, port, collection_name, *, extraction_dir="extraction"):
    """Embed extracted documents and insert them into a Milvus collection.

    The function:
      1. Picks the first ``.pkl`` file in ``extraction_dir``, in sorted
         filename order, so the choice is deterministic.
      2. Reads it as a pandas DataFrame.
      3. Encodes its ``content`` column with the module-level
         SentenceTransformer ``model``.
      4. Inserts the ``path`` column and the resulting vectors into
         ``collection_name``, then flushes.

    Args:
        host: Unused. The default pymilvus connection must already be open.
        port: Unused; see ``host``.
        collection_name: Target collection. Its non-auto-id fields must be
            ``path`` then ``content_vector`` (the layout produced by
            create_milvus_collection), because the insert is column-based.
        extraction_dir: Directory searched for pickled DataFrames.

    Raises:
        FileNotFoundError: If ``extraction_dir`` does not exist.
    """
    pkl_files = sorted(
        name for name in os.listdir(extraction_dir) if name.endswith(".pkl")
    )
    if not pkl_files:
        print(f"No .pkl files found in the '{extraction_dir}' directory.")
        return

    df_path = os.path.join(extraction_dir, pkl_files[0])
    # NOTE(security): unpickling can execute arbitrary code. Only point this at
    # files produced by this pipeline, never at untrusted input.
    df = pd.read_pickle(df_path)
    if df.empty:
        print(f"No rows found in {df_path}; nothing to load.")
        return

    # One batched encode call instead of one model invocation per row.
    # Assumes ``content`` holds strings (TODO confirm against the extractor).
    vectors = model.encode(df["content"].tolist())

    collection = Collection(collection_name)
    collection.insert([df["path"].tolist(), vectors.tolist()])
    collection.flush()
    print(f"Data from {df_path} loaded into Milvus collection {collection_name}.")
def connect_to_milvus(host, port, max_retries, retry_delay):
    """Open the default pymilvus connection to ``host:port``, retrying on failure.

    Args:
        host: Milvus server hostname, passed to ``connections.connect``.
        port: Milvus server port, passed to ``connections.connect``.
        max_retries: Total number of connection attempts. A value <= 0 makes
            no attempt.
        retry_delay: Seconds to sleep between consecutive attempts.

    Returns:
        True as soon as an attempt succeeds. False if every attempt failed or
        no attempt was made. Previously the function could fall through and
        return None.
    """
    for attempt in range(1, max_retries + 1):
        try:
            connections.connect(host=host, port=port)
        except Exception as e:  # every connection failure is treated as retryable
            print(f"Error connecting to Milvus: {e}")
            if attempt < max_retries:
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
        else:
            print(f"Successfully connected to Milvus at {host}:{port}")
            return True
    print("Max retries reached. Could not connect to Milvus.")
    return False
def initialize_milvus(host, port, collection_name, dimension, max_retries, retry_delay):
    """Connect to Milvus, ensure the collection exists, and load the extracted data.

    Args:
        host: Milvus server hostname.
        port: Milvus server port.
        collection_name: Collection to create (if missing) and fill.
        dimension: Vector width used when the collection has to be created.
        max_retries: Number of connection attempts.
        retry_delay: Seconds between connection attempts.

    Returns:
        True if connecting, collection setup, data loading and disconnecting
        all succeeded; False otherwise. ``Exception`` subclasses raised along
        the way are printed rather than propagated.
    """
    if not connect_to_milvus(host, port, max_retries, retry_delay):
        return False  # connection failed, so there is nothing to clean up

    succeeded = False
    try:
        create_milvus_collection(host, port, collection_name, dimension)
        load_data_to_milvus(host, port, collection_name)
        succeeded = True
    except Exception as e:  # top-level boundary: report and signal failure
        print(f"Error during initialization: {e}")
    finally:
        # Release the connection even when setup or loading failed. Previously
        # a failure skipped the disconnect and leaked the connection.
        try:
            connections.disconnect(alias='default')
        except Exception as e:
            print(f"Error during initialization: {e}")
            succeeded = False
    return succeeded
if __name__ == "__main__":
    # Each setting falls back to its module-level default when the
    # corresponding environment variable is unset.
    milvus_host = os.environ.get("MILVUS_HOST", DEFAULT_MILVUS_HOST)
    milvus_port = os.environ.get("MILVUS_PORT", DEFAULT_MILVUS_PORT)
    collection_name = os.environ.get("COLLECTION_NAME", DEFAULT_COLLECTION_NAME)
    dimension = int(os.environ.get("DIMENSION", DEFAULT_DIMENSION))
    max_retries = int(os.environ.get("MAX_RETRIES", DEFAULT_MAX_RETRIES))
    retry_delay = int(os.environ.get("RETRY_DELAY", DEFAULT_RETRY_DELAY))

    succeeded = initialize_milvus(
        milvus_host, milvus_port, collection_name, dimension, max_retries, retry_delay
    )
    # Report failure to the shell/orchestrator through the exit status.
    # Previously the script always exited 0.
    sys.exit(0 if succeeded else 1)