|
import faiss |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from typing import List, Optional, Tuple |
|
from langchain_community.graphs import Neo4jGraph |
|
import pickle |
|
|
|
class FAISSVectorStore:
    """Cosine-similarity vector store backed by a FAISS ``IndexFlatIP``.

    Every vector is L2-normalised before it is added or queried, so the inner
    products returned by the index are cosine similarities. Search hits are
    resolved to Neo4j nodes whose ``id`` property equals the hit's FAISS id
    (the position at which the vector was added to the index).
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        dimension: int = 384,
        embedding_file: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Create an empty index and optionally populate it from disk.

        Args:
            model_name: SentenceTransformer model used to encode queries. When
                ``None`` no model is loaded and ``similarity_search`` raises.
            dimension: Width of the vectors stored in the index.
            embedding_file: Optional ``.pkl`` / ``.npy`` file whose vectors are
                added immediately via ``load_embeddings``.
            trust_remote_code: Forwarded to ``SentenceTransformer``.
        """
        self.model = (
            SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
            if model_name is not None
            else None
        )
        self.index = faiss.IndexFlatIP(dimension)
        self.dimension = dimension
        if embedding_file:
            self.load_embeddings(embedding_file)

    def load_embeddings(self, file_path: str) -> None:
        """Load an embedding matrix from a ``.pkl`` or ``.npy`` file and add it.

        Raises:
            ValueError: if the file extension is neither ``.pkl`` nor ``.npy``.
        """
        if file_path.endswith('.pkl'):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only load .pkl files from a trusted source.
            with open(file_path, 'rb') as f:
                embeddings = pickle.load(f)
        elif file_path.endswith('.npy'):
            embeddings = np.load(file_path)
        else:
            raise ValueError("Unsupported file format. Use .pkl or .npy")

        self.add_embeddings(embeddings)

    def add_embeddings(self, embeddings: np.ndarray) -> None:
        """L2-normalise ``embeddings`` and append them to the index.

        Args:
            embeddings: Array-like of shape ``(n, dimension)``. A single 1-D
                vector is also accepted. The data is copied into a
                C-contiguous float32 array (the layout FAISS requires), so the
                caller's array is left unmodified and plain lists work too.

        Raises:
            ValueError: if the vectors are not ``dimension`` wide.
        """
        vectors = np.atleast_2d(np.array(embeddings, dtype=np.float32, order="C"))
        if vectors.ndim != 2 or vectors.shape[1] != self.dimension:
            raise ValueError(
                f"Expected embeddings of shape (n, {self.dimension}), "
                f"got {vectors.shape}"
            )
        faiss.normalize_L2(vectors)  # in place, on our private copy
        self.index.add(vectors)

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        use_mmr: bool = False,
        lambda_param: float = 0.5,
        doc_types: Optional[List[str]] = None,
        neo4j_graph: Optional["Neo4jGraph"] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Return the Neo4j documents most similar to ``query``.

        Args:
            query: Free-text query, encoded with ``self.model``.
            k: Maximum number of results.
            use_mmr: Re-rank with Maximal Marginal Relevance for diversity.
            lambda_param: MMR trade-off. 1.0 is pure relevance; 0.0 is pure
                diversity.
            doc_types: Optional Neo4j labels a hit's node must carry. Hits
                with no matching node are dropped after retrieval, so fewer
                than ``k`` results may be returned.
            neo4j_graph: Graph used to resolve FAISS ids to nodes.

        Returns:
            ``(results, ids)``. ``results`` holds dicts with ``'document'``
            (``"[Label] {node properties}"``) and ``'score'`` (cosine
            similarity). ``ids`` holds the matching FAISS ids.

        Raises:
            ValueError: if no encoder model or no graph is available.
        """
        if self.model is None:
            raise ValueError(
                "No embedding model loaded; pass model_name to FAISSVectorStore"
            )
        if neo4j_graph is None:
            raise ValueError("neo4j_graph is required to resolve search results")

        # FAISS needs a C-contiguous float32 matrix of shape (1, dimension).
        query_vector = np.ascontiguousarray(
            self.model.encode([query]), dtype=np.float32
        )
        faiss.normalize_L2(query_vector)

        if use_mmr:
            return self._mmr_search(query_vector, k, lambda_param, neo4j_graph, doc_types)
        return self._simple_search(query_vector, k, neo4j_graph, doc_types)

    def _simple_search(
        self,
        query_vector: np.ndarray,
        k: int,
        neo4j_graph: "Neo4jGraph",
        doc_types: Optional[List[str]] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Plain top-``k`` nearest-neighbour search, resolved through Neo4j."""
        if k <= 0:
            return [], []

        distances, indices = self.index.search(query_vector, k)

        results: List[dict] = []
        results_idx: List[int] = []
        for score, idx in zip(distances[0], indices[0]):
            # FAISS pads with -1 when the index holds fewer than k vectors.
            if idx < 0:
                continue
            faiss_id = int(idx)
            document = self._get_text_by_index(neo4j_graph, faiss_id, doc_types)
            if document is None:
                continue
            results.append({'document': document, 'score': float(score)})
            results_idx.append(faiss_id)

        return results, results_idx

    def _mmr_search(
        self,
        query_vector: np.ndarray,
        k: int,
        lambda_param: float,
        neo4j_graph: "Neo4jGraph",
        doc_types: Optional[List[str]] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Maximal Marginal Relevance search over the ``2 * k`` nearest vectors.

        Candidates are picked greedily. Each step selects the candidate
        maximising ``lambda_param * relevance - (1 - lambda_param) *
        redundancy``, where redundancy is the highest cosine similarity to
        any already-selected candidate.
        """
        initial_k = min(k * 2, self.index.ntotal)
        if k <= 0 or initial_k <= 0:
            return [], []

        distances, indices = self.index.search(query_vector, initial_k)

        # Drop any -1 padding so candidate positions map to real vectors.
        valid = indices[0] >= 0
        candidate_ids = indices[0][valid]
        relevance = distances[0][valid]
        n_candidates = len(candidate_ids)
        if n_candidates == 0:
            return [], []

        embeddings = self._reconstruct_embeddings(candidate_ids)

        selected: List[int] = []  # candidate positions, in selection order
        remaining = list(range(n_candidates))
        # Max similarity of each candidate to the selected set. It is updated
        # incrementally with only the newest pick, which gives the same value
        # as recomputing the max over all selected items every round.
        redundancy = [-np.inf] * n_candidates

        for _ in range(min(k, n_candidates)):
            if selected:
                newest = embeddings[selected[-1]]
                for pos in remaining:
                    redundancy[pos] = max(
                        redundancy[pos],
                        self._cosine_similarity(embeddings[pos], newest),
                    )
                best = max(
                    remaining,
                    key=lambda pos: lambda_param * relevance[pos]
                    - (1 - lambda_param) * redundancy[pos],
                )
            else:
                # Nothing selected yet: take the most relevant candidate.
                best = max(remaining, key=lambda pos: relevance[pos])
            selected.append(best)
            remaining.remove(best)

        results: List[dict] = []
        results_idx: List[int] = []
        for pos in selected:
            faiss_id = int(candidate_ids[pos])
            document = self._get_text_by_index(neo4j_graph, faiss_id, doc_types)
            if document is None:
                continue
            results.append({'document': document, 'score': float(relevance[pos])})
            # Report the FAISS id (not the candidate position), matching _simple_search.
            results_idx.append(faiss_id)

        return results, results_idx

    def _reconstruct_embeddings(self, indices: np.ndarray) -> np.ndarray:
        """Fetch the stored (normalised) vectors for the given FAISS ids."""
        return self.index.reconstruct_batch(indices)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of ``a`` and ``b``. Returns 0.0 if either is all zeros."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def _get_text_by_index(
        self,
        neo4j_graph: "Neo4jGraph",
        index,
        doc_types: Optional[List[str]],
    ) -> Optional[str]:
        """Look up the node whose ``id`` property equals ``index``.

        Args:
            neo4j_graph: Graph exposing ``query(cypher, params)``.
            index: FAISS id to match against ``n.id``.
            doc_types: ``None`` matches any label. Otherwise the labels are
                tried in order and the first one with a match wins. An empty
                list matches nothing.

        Returns:
            ``"[FirstLabel] {node}"`` for the first matching node, or ``None``.
        """
        # FAISS yields numpy integer scalars. Pass a plain Python value, since
        # not every Neo4j driver version can serialise numpy types as parameters.
        if isinstance(index, np.generic):
            index = index.item()
        params = {"index": index}

        result = []
        if doc_types is None:
            query = """
            MATCH (n)
            WHERE n.id = $index
            RETURN n AS document, labels(n) AS document_type, n.id AS node_id
            """
            result = neo4j_graph.query(query, params)
        else:
            for doc_type in doc_types:
                # Labels cannot be query parameters, so quote the label as a
                # backtick-delimited identifier (doubling embedded backticks).
                # This stops Cypher injection through doc_type.
                label = "`" + str(doc_type).replace("`", "``") + "`"
                query = f"""
                MATCH (n:{label})
                WHERE n.id = $index
                RETURN n AS document, labels(n) AS document_type, n.id AS node_id
                """
                result = neo4j_graph.query(query, params)
                if result:
                    break

        if not result:
            return None
        first = result[0]
        labels = first['document_type']
        first_label = labels[0] if labels else ""
        return f"[{first_label}] {first['document']}"
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: compare plain top-k search with MMR re-ranking.
    import os

    # A query encoder is required: without model_name, similarity_search has
    # nothing to encode the query with. It must be the same model that produced
    # the stored embeddings. all-MiniLM-L6-v2 emits 384-dim vectors, matching
    # `dimension`.
    vector_store = FAISSVectorStore(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        dimension=384,
        embedding_file="path/to/your/embeddings.pkl",
    )

    # Read connection settings from the environment instead of hard-coding
    # credentials. The defaults are local-development placeholders.
    neo4j_graph = Neo4jGraph(
        url=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
        username=os.environ.get("NEO4J_USERNAME", "neo4j"),
        password=os.environ.get("NEO4J_PASSWORD", "password"),
    )

    query = "How to start a long journey"

    # similarity_search returns (results, faiss_ids). Only the results are shown here.
    results_simple, _ = vector_store.similarity_search(
        query, k=5, use_mmr=False, neo4j_graph=neo4j_graph
    )
    results_mmr, _ = vector_store.similarity_search(
        query, k=5, use_mmr=True, lambda_param=0.5, neo4j_graph=neo4j_graph
    )

    for mode, results in (("without MMR", results_simple), ("with MMR", results_mmr)):
        print(f"Top 5 similar texts for query: '{query}' ({mode})")
        for i, result in enumerate(results, 1):
            # Result dicts carry 'document' and 'score' (there is no 'text' key).
            print(f"{i}. Text: {result['document']}")
            print(f" Score: {result['score']}")
        print()