# Mg_Alloy_Knowledgebase_v2 / vectorstore.py
import os
import pickle
import sys

from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient, models

# Load environment variables from .env (e.g. API keys)
load_dotenv()
# Initialize embedding model
model_id = "Snowflake/snowflake-arctic-embed-l"
EMBEDDINGS = HuggingFaceEmbeddings(
model_name=model_id,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
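# Sparse BM25 embeddings provide the lexical (keyword) half of hybrid retrieval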
SPARSE_EMBEDDINGS = FastEmbedSparse(model_name="Qdrant/BM25")
class VectorStore:
def __init__(self, collection_name="testCollection"):
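        """Create the Qdrant collection (dense + sparse vectors) if it does not
        already exist, and wrap it as a LangChain vector store."""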
self.collection_name = collection_name
self.client = QdrantClient(":memory:")
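        # NOTE: ":memory:" runs Qdrant in-process; the collection is ephemeral
        # and is lost when the process exits. Pass a path or URL to QdrantClient
        # if persistence is needed.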
# create the collection if it doesn't exist
if not self.client.collection_exists(collection_name):
self.client.create_collection(
collection_name=collection_name,
vectors_config={
"dense_vector": models.VectorParams(
size=1024, distance=models.Distance.COSINE # arctic embed dim: 1024
)
},
sparse_vectors_config={
"sparse_vector": models.SparseVectorParams(
index=models.SparseIndexParams(
on_disk=False,
)
)
},
)
print(f"\nCollection {collection_name} created")
else:
print(f"\nLoading existing collection: {collection_name}")
self._vector_store = self._as_vector_store(collection_name)
def get_collection_documents(self):
"""get all documents in the collection"""
records = self._vector_store.client.retrieve(
ids=list(range(1, self.client.count(self.collection_name).count + 1)),
collection_name=self.collection_name,
with_payload=True
)
documents = []
for record in records:
documents.append(Document(page_content=record.payload['page_content'], metadata=record.payload['metadata']))
return documents
def _as_vector_store(self, collection_name):
return QdrantVectorStore(
client=self.client,
collection_name=collection_name,
embedding=EMBEDDINGS,
sparse_embedding=SPARSE_EMBEDDINGS,
retrieval_mode=RetrievalMode.HYBRID,
vector_name="dense_vector",
sparse_vector_name="sparse_vector",
)
def as_retriever(self, k=3):
return self._vector_store.as_retriever(
search_type="mmr",
search_kwargs={"k": k, "lambda_mult": 0.5},
)
def add_documents(self, documents, batch_size=4):
"""add documents to the collection"""
# Skip if no documents to add
if not documents:
return
        # get the number of points already in the collection
        point_count = self.client.count(self.collection_name)
        # ids of the existing points (point ids in this store start at 1)
        ids = list(range(1, point_count.count + 1))
# Get the existing documents in the collection
records = self._vector_store.client.retrieve(
ids=ids,
collection_name=self.collection_name,
with_payload=True
)
# Extract unique titles from metadata
existing_docs = list(set([record.payload['metadata']['filename'] for record in records]))
# Filter out documents that already exist
documents = [doc for doc in documents if doc.metadata["filename"] not in existing_docs]
# Skip if all documents already exist
if not documents:
print("All documents already exist in collection. Skipping upload.")
return
# create a list of ids for the documents
ids = list(range(point_count.count + 1, point_count.count + len(documents) + 1))
# add the documents to the collection
self._vector_store.add_documents(documents=documents, ids=ids)
@staticmethod
def load_chunks_as_documents(path):
file_list = []
if os.path.isfile(path) and path.endswith('.pkl'):
# Single pkl file
file_list.append(path)
elif os.path.isdir(path):
# Directory of pkl files
for filename in os.listdir(path):
if filename.endswith('.pkl'):
path_ = os.path.join(path, filename)
file_list.append(path_)
loaded_chunk_data = {}
for file in file_list:
with open(file, 'rb') as f:
data = pickle.load(f)
loaded_chunk_data[data["filename"]] = data["chunks"]
print(f"Loaded {len(loaded_chunk_data)} documents from {path}:")
for i, doc_name in enumerate(loaded_chunk_data.keys()):
print(f" {i+1}. {doc_name}")
        # Convert the chunks to LangChain Documents
documents = []
for fname in loaded_chunk_data.keys():
chunks = loaded_chunk_data[fname]
for chunk in chunks:
documents.append(
Document(
page_content=chunk.page_content,
metadata=chunk.metadata
)
)
return documents
def inspect_collection(self):
"""inspect the collection"""
print(f"Collection {self.collection_name} has {self.client.count(self.collection_name).count} documents")
# Get the existing documents in the collection
point_count = self.client.count(self.collection_name)
ids = list(range(1, point_count.count + 1))
records = self._vector_store.client.retrieve(
ids=ids,
collection_name=self.collection_name,
with_payload=True
)
# Extract unique titles from metadata
existing_docs = list(set([record.payload['metadata']['filename'] for record in records]))
print(f"Documents in collection:")
for i, doc_name in enumerate(existing_docs):
print(f" {i+1}. {doc_name}")
""" def main():
collection_name = input("\nEnter a collection name to add documents:").strip()
if not collection_name:
collection_name = "testCollection"
# Load the documents
if not os.path.exists(configs["CONTEXTUAL_CHUNKS_FOLDER"]):
print(f"Error: {configs['CONTEXTUAL_CHUNKS_FOLDER']} does not exist")
sys.exit(1)
documents = VectorStore.load_chunks_as_documents(configs["CONTEXTUAL_CHUNKS_FOLDER"])
# Initialize the vector store
vector_store = VectorStore(collection_name)
# Add the documents to the vector store
vector_store.add_documents(documents) """
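
# --- Usage sketch (illustrative only, not part of the original module) ---
# Assumes a hypothetical local folder "contextual_chunks/" containing .pkl
# chunk files; adjust the path, collection name, and query for your setup.
if __name__ == "__main__":
    chunks_folder = "contextual_chunks"  # hypothetical path
    if not os.path.exists(chunks_folder):
        sys.exit(f"Error: {chunks_folder} does not exist")

    documents = VectorStore.load_chunks_as_documents(chunks_folder)

    vector_store = VectorStore(collection_name="testCollection")
    vector_store.add_documents(documents)
    vector_store.inspect_collection()

    # Hybrid (dense + BM25) retrieval with MMR re-ranking
    retriever = vector_store.as_retriever(k=3)
    results = retriever.invoke("creep resistance of Mg-Al alloys")
    for doc in results:
        print(doc.metadata.get("filename"), "-", doc.page_content[:80])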