# bjhk / doc_explorer / vectorstore.py
# Uploaded by heymenn in commit 6aaddef ("Upload 15 files", verified).
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Optional, Tuple
from langchain_community.graphs import Neo4jGraph
import pickle
class FAISSVectorStore:
    """Cosine-similarity vector store backed by FAISS, resolving hits via Neo4j.

    Vectors are kept in a ``faiss.IndexFlatIP``. Stored vectors and query
    vectors are both L2-normalised before use, so the inner-product scores
    FAISS returns are cosine similarities. A hit's FAISS position is looked up
    in Neo4j as the node property ``n.id`` to obtain the document returned to
    the caller.
    """

    def __init__(self, model_name: Optional[str] = None, dimension: int = 384, embedding_file: Optional[str] = None, trust_remote_code=False):
        """Create an empty index and optionally load embeddings into it.

        Args:
            model_name: SentenceTransformer model used to encode queries. When
                ``None`` no model is loaded and ``similarity_search`` cannot
                be used (the index can still be filled).
            dimension: Width of every stored vector.
            embedding_file: Optional ``.pkl`` / ``.npy`` file handed to
                ``load_embeddings``.
            trust_remote_code: Forwarded to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code) if model_name is not None else None
        self.index = faiss.IndexFlatIP(dimension)
        self.dimension = dimension
        if embedding_file:
            self.load_embeddings(embedding_file)

    def load_embeddings(self, file_path: str):
        """Load an embedding matrix from ``file_path`` and add it to the index.

        Raises:
            ValueError: if the extension is neither ``.pkl`` nor ``.npy``, or
                the loaded vectors do not match ``self.dimension``.
        """
        if file_path.endswith('.pkl'):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file; only load .pkl files from a trusted source.
            with open(file_path, 'rb') as f:
                embeddings = pickle.load(f)
        elif file_path.endswith('.npy'):
            embeddings = np.load(file_path)
        else:
            raise ValueError("Unsupported file format. Use .pkl or .npy")
        self.add_embeddings(embeddings)

    def add_embeddings(self, embeddings: np.ndarray):
        """L2-normalise ``embeddings`` and append them to the index.

        Anything ``np.asarray`` accepts (a list of lists, a float64 array, a
        single 1-D vector) is converted to the C-contiguous float32 layout
        FAISS requires. If the input already has that layout it is normalised
        in place, so the caller's array is modified.

        Raises:
            ValueError: if the vectors are not ``(n, self.dimension)``.
        """
        embeddings = np.atleast_2d(np.ascontiguousarray(embeddings, dtype=np.float32))
        if embeddings.ndim != 2 or embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Expected embeddings of shape (n, {self.dimension}), got {embeddings.shape}"
            )
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

    def similarity_search(self, query: str, k: int = 5, use_mmr: bool = False, lambda_param: float = 0.5, doc_types: Optional[list[str]] = None, neo4j_graph: Optional["Neo4jGraph"] = None) -> Tuple[List[dict], List[int]]:
        """Encode ``query`` and return its nearest documents.

        Args:
            query: Free-text query.
            k: Maximum number of hits to return.
            use_mmr: Re-rank with Maximal Marginal Relevance for diversity.
            lambda_param: MMR trade-off; 1.0 is pure relevance, 0.0 is pure
                diversity.
            doc_types: Optional Neo4j labels to restrict hits to. Filtering
                happens after retrieval, so fewer than ``k`` hits may return.
            neo4j_graph: Graph used to resolve FAISS ids to documents.

        Returns:
            ``(results, ids)`` where each result is
            ``{'document': str, 'score': float}`` and ``ids`` holds the
            matching FAISS ids in the same order.

        Raises:
            ValueError: if no model was loaded or ``neo4j_graph`` is missing.
        """
        if self.model is None:
            raise ValueError("No embedding model loaded; pass model_name to FAISSVectorStore to enable queries")
        if neo4j_graph is None:
            raise ValueError("neo4j_graph is required to resolve search hits to documents")
        # normalize_L2 works in place and only on contiguous float32 arrays.
        query_vector = np.ascontiguousarray(self.model.encode([query]), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        if use_mmr:
            return self._mmr_search(query_vector, k, lambda_param, neo4j_graph, doc_types)
        return self._simple_search(query_vector, k, neo4j_graph, doc_types)

    def _simple_search(self, query_vector: np.ndarray, k: int, neo4j_graph: "Neo4jGraph", doc_types: Optional[list[str]] = None) -> Tuple[List[dict], List[int]]:
        """Return the top-``k`` hits that resolve to a Neo4j document."""
        if k <= 0 or self.index.ntotal == 0:
            return [], []
        # Never ask for more neighbours than exist (FAISS would pad with -1).
        distances, indices = self.index.search(query_vector, min(k, self.index.ntotal))
        results = []
        results_idx = []
        for score, idx in zip(distances[0], indices[0]):
            if idx < 0:  # defensive: FAISS marks missing neighbours with -1
                continue
            faiss_id = int(idx)
            document = self._get_text_by_index(neo4j_graph, faiss_id, doc_types)
            if document is not None:
                results.append({
                    'document': document,
                    'score': score
                })
                results_idx.append(faiss_id)
        return results, results_idx

    def _mmr_search(self, query_vector: np.ndarray, k: int, lambda_param: float, neo4j_graph: "Neo4jGraph", doc_types: Optional[list[str]] = None) -> Tuple[List[dict], List[int]]:
        """Pick ``k`` hits from the top ``2k`` candidates using MMR.

        Each round picks the candidate maximising
        ``lambda * relevance - (1 - lambda) * max_similarity_to_selected``.
        The returned ids are FAISS ids, matching ``_simple_search``.
        """
        candidate_count = min(k * 2, self.index.ntotal)
        if candidate_count <= 0:
            return [], []
        distances, indices = self.index.search(query_vector, candidate_count)
        scores = distances[0]
        ids = indices[0]
        # Stored vectors of the candidates, needed for redundancy scoring.
        initial_embeddings = self._reconstruct_embeddings(ids)
        # Positions into the candidate list (not FAISS ids).
        selected_positions = []
        unselected_positions = list(range(len(ids)))
        for _ in range(min(k, len(ids))):
            mmr_scores = []
            for i in unselected_positions:
                if not selected_positions:
                    mmr_scores.append((i, scores[i]))
                else:
                    embedding_i = initial_embeddings[i]
                    redundancy = max(self._cosine_similarity(embedding_i, initial_embeddings[j]) for j in selected_positions)
                    mmr_scores.append((i, lambda_param * scores[i] - (1 - lambda_param) * redundancy))
            best = max(mmr_scores, key=lambda x: x[1])[0]
            selected_positions.append(best)
            unselected_positions.remove(best)
        results = []
        results_idx = []
        for pos in selected_positions:
            faiss_id = int(ids[pos])
            document = self._get_text_by_index(neo4j_graph, faiss_id, doc_types)
            if document is not None:
                results.append({
                    'document': document,
                    'score': scores[pos]
                })
                # Report the FAISS id, not the candidate position.
                results_idx.append(faiss_id)
        return results, results_idx

    def _reconstruct_embeddings(self, indices: np.ndarray) -> np.ndarray:
        """Return the stored vectors for the given FAISS ids."""
        return self.index.reconstruct_batch(indices)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors."""
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def _get_text_by_index(self, neo4j_graph, index, doc_types):
        """Resolve a FAISS id to ``"[<first label>] <node>"`` via Neo4j.

        Matches the node whose ``id`` property equals ``index``. When
        ``doc_types`` is given, each label is tried in order and the first
        match wins. Returns ``None`` if nothing matches.
        """
        # numpy integers are converted so the driver receives a plain int.
        params = {"index": int(index)}
        result = []
        if doc_types is None:
            query = """
            MATCH (n)
            WHERE n.id = $index
            RETURN n AS document, labels(n) AS document_type, n.id AS node_id
            """
            result = neo4j_graph.query(query, params)
        else:
            for doc_type in doc_types:
                # Labels cannot be query parameters, so quote the identifier
                # with backticks (escaping embedded backticks by doubling)
                # to prevent Cypher injection.
                label = "`" + str(doc_type).replace("`", "``") + "`"
                query = f"""
                MATCH (n:{label})
                WHERE n.id = $index
                RETURN n AS document, labels(n) AS document_type, n.id AS node_id
                """
                result = neo4j_graph.query(query, params)
                if result:
                    break
        if result:
            return f"[{result[0]['document_type'][0]}] {result[0]['document']}"
        return None
# Example usage
if __name__ == "__main__":
    k = 5
    # The query model must be the one that produced the stored embeddings;
    # all-MiniLM-L6-v2 is one 384-dimensional example. Without a model_name
    # similarity_search cannot encode the query.
    vector_store = FAISSVectorStore(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        dimension=384,
        embedding_file="path/to/your/embeddings.pkl",  # or .npy file
    )
    # Initialize Neo4jGraph (replace with your actual Neo4j connection details)
    neo4j_graph = Neo4jGraph(
        url="bolt://localhost:7687",
        username="neo4j",
        password="password"
    )
    # Perform a similarity search with and without MMR.
    # similarity_search returns (results, faiss_ids); only results are shown.
    query = "How to start a long journey"
    results_simple, _ = vector_store.similarity_search(query, k=k, use_mmr=False, neo4j_graph=neo4j_graph)
    results_mmr, _ = vector_store.similarity_search(query, k=k, use_mmr=True, lambda_param=0.5, neo4j_graph=neo4j_graph)
    # Print the results
    for label, results in (("without MMR", results_simple), ("with MMR", results_mmr)):
        print(f"Top {k} similar texts for query: '{query}' ({label})")
        for i, result in enumerate(results, 1):
            print(f"{i}. Text: {result['document']}")
            print(f" Score: {result['score']}")
        print()