import pickle
from typing import List, Optional, Tuple

import faiss
import numpy as np
from langchain_community.graphs import Neo4jGraph
from sentence_transformers import SentenceTransformer


class FAISSVectorStore:
    """Cosine-similarity vector store backed by a flat FAISS inner-product index.

    Vectors are L2-normalized before they are added and before querying, so the
    inner-product scores produced by ``faiss.IndexFlatIP`` are cosine
    similarities. Search hits are resolved to documents by querying Neo4j for
    the node whose ``id`` property equals the FAISS row position of the hit.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        dimension: int = 384,
        embedding_file: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Create the store and optionally preload embeddings.

        Args:
            model_name: SentenceTransformer model used to encode queries. When
                ``None`` no model is loaded and :meth:`similarity_search`
                raises ``ValueError``; embeddings can still be loaded.
            dimension: Embedding width; must match the stored embeddings and
                the model's output size.
            embedding_file: Optional ``.pkl`` or ``.npy`` file to load at
                construction time (see :meth:`load_embeddings`).
            trust_remote_code: Forwarded to ``SentenceTransformer``.
        """
        self.model = (
            SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
            if model_name is not None
            else None
        )
        self.index = faiss.IndexFlatIP(dimension)
        self.dimension = dimension
        if embedding_file:
            self.load_embeddings(embedding_file)

    def load_embeddings(self, file_path: str):
        """Load an embedding matrix from ``file_path`` and add it to the index.

        Supported formats are ``.pkl`` (pickled array-like) and ``.npy``.

        Raises:
            ValueError: If the file extension is not supported, or if the
                loaded data does not have shape ``(n, dimension)``.
        """
        lowered = file_path.lower()
        if lowered.endswith('.pkl'):
            # SECURITY: pickle.load can execute arbitrary code; only load
            # embedding files from trusted sources.
            with open(file_path, 'rb') as f:
                embeddings = pickle.load(f)
        elif lowered.endswith('.npy'):
            embeddings = np.load(file_path)
        else:
            raise ValueError("Unsupported file format. Use .pkl or .npy")
        self.add_embeddings(embeddings)

    def add_embeddings(self, embeddings: np.ndarray):
        """L2-normalize ``embeddings`` and append them to the FAISS index.

        The input is copied into a C-contiguous float32 matrix first. FAISS
        requires that layout, and ``faiss.normalize_L2`` works in place, so the
        copy keeps the caller's array unmodified. A single 1-D vector is
        treated as one row. Empty input is a no-op.

        Raises:
            ValueError: If the data is not ``(n, self.dimension)``-shaped.
        """
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.size == 0:
            return
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        if vectors.ndim != 2 or vectors.shape[1] != self.dimension:
            raise ValueError(
                f"Expected embeddings of shape (n, {self.dimension}), "
                f"got {vectors.shape}"
            )
        faiss.normalize_L2(vectors)
        self.index.add(vectors)

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        use_mmr: bool = False,
        lambda_param: float = 0.5,
        doc_types: Optional[List[str]] = None,
        neo4j_graph: Optional[Neo4jGraph] = None,
        fetch_k: Optional[int] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Encode ``query`` and return the most similar documents.

        Args:
            query: Free-text query, encoded with the configured model.
            k: Number of hits to retrieve before document lookup/filtering.
            use_mmr: Use Maximal Marginal Relevance re-ranking for diversity.
            lambda_param: MMR trade-off; 1.0 is pure relevance, 0.0 is pure
                diversity. Ignored unless ``use_mmr``.
            doc_types: Optional Neo4j labels to restrict matches to. Hits
                without a matching node are dropped, so fewer than ``k``
                results may be returned.
            neo4j_graph: Graph used to resolve FAISS ids to documents.
            fetch_k: MMR candidate pool size (defaults to ``2 * k``).

        Returns:
            ``(results, ids)``, where ``results`` is a list of
            ``{'document': str, 'score': float}`` dicts and ``ids`` are the
            matching FAISS row ids.

        Raises:
            ValueError: If no model was configured or ``neo4j_graph`` is None.
        """
        if self.model is None:
            raise ValueError(
                "No SentenceTransformer model loaded; pass model_name to FAISSVectorStore()."
            )
        if neo4j_graph is None:
            raise ValueError("neo4j_graph is required to resolve search hits to documents.")
        # FAISS needs a C-contiguous float32 (n_queries, dimension) matrix.
        query_vector = np.ascontiguousarray(self.model.encode([query]), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        if use_mmr:
            return self._mmr_search(
                query_vector, k, lambda_param, neo4j_graph, doc_types, fetch_k=fetch_k
            )
        return self._simple_search(query_vector, k, neo4j_graph, doc_types)

    def _simple_search(
        self,
        query_vector: np.ndarray,
        k: int,
        neo4j_graph: Neo4jGraph,
        doc_types: Optional[List[str]] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Return the top-``k`` hits by cosine similarity, resolved via Neo4j."""
        if k <= 0:
            return [], []
        distances, indices = self.index.search(query_vector, k)
        results: List[dict] = []
        results_idx: List[int] = []
        for score, idx in zip(distances[0], indices[0]):
            # FAISS pads with -1 when the index holds fewer than k vectors.
            if idx < 0:
                continue
            document = self._get_text_by_index(neo4j_graph, idx, doc_types)
            if document is not None:
                results.append({'document': document, 'score': float(score)})
                results_idx.append(int(idx))
        return results, results_idx

    def _mmr_search(
        self,
        query_vector: np.ndarray,
        k: int,
        lambda_param: float,
        neo4j_graph: Neo4jGraph,
        doc_types: Optional[List[str]] = None,
        fetch_k: Optional[int] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Select ``k`` diverse hits from a candidate pool with MMR.

        The candidate pool is the top ``fetch_k`` (default ``2 * k``) hits,
        capped at the index size. The first pick is the most relevant
        candidate. Each later pick maximizes
        ``lambda * relevance - (1 - lambda) * max_similarity_to_selected``.
        The ``score`` in each result is the raw query similarity, not the
        MMR score. Returned ids are FAISS row ids, matching
        :meth:`_simple_search`.
        """
        if fetch_k is None:
            fetch_k = k * 2
        initial_k = min(fetch_k, self.index.ntotal)
        if k <= 0 or initial_k <= 0:
            return [], []
        distances, indices = self.index.search(query_vector, initial_k)

        # Drop FAISS -1 padding so reconstruction and lookups see real ids only.
        valid = indices[0] >= 0
        candidate_ids = indices[0][valid]
        relevance = distances[0][valid]
        if len(candidate_ids) == 0:
            return [], []

        # Reconstruct (normalized) stored vectors to measure redundancy.
        candidate_embeddings = self._reconstruct_embeddings(candidate_ids)

        selected: List[int] = []  # positions within candidate_ids
        unselected = list(range(len(candidate_ids)))
        for _ in range(min(k, len(candidate_ids))):
            mmr_scores = []
            for pos in unselected:
                if not selected:
                    mmr_scores.append((pos, relevance[pos]))
                else:
                    redundancy = max(
                        self._cosine_similarity(candidate_embeddings[pos], candidate_embeddings[j])
                        for j in selected
                    )
                    mmr_scores.append(
                        (pos, lambda_param * relevance[pos] - (1 - lambda_param) * redundancy)
                    )
            best_pos = max(mmr_scores, key=lambda item: item[1])[0]
            selected.append(best_pos)
            unselected.remove(best_pos)

        results: List[dict] = []
        results_idx: List[int] = []
        for pos in selected:
            faiss_id = int(candidate_ids[pos])
            document = self._get_text_by_index(neo4j_graph, faiss_id, doc_types)
            if document is not None:
                results.append({'document': document, 'score': float(relevance[pos])})
                # Report the FAISS id, not the position in the candidate pool.
                results_idx.append(faiss_id)
        return results, results_idx

    def _reconstruct_embeddings(self, indices: np.ndarray) -> np.ndarray:
        """Return the stored vectors for the given FAISS ids, one row per id."""
        return self.index.reconstruct_batch(np.asarray(indices, dtype=np.int64))

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of ``a`` and ``b``; 0.0 if either is a zero vector."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def _get_text_by_index(self, neo4j_graph, index, doc_types):
        """Resolve a FAISS id to a ``"[Label] <node>"`` string via Neo4j.

        Looks up a node whose ``id`` property equals ``index``. When
        ``doc_types`` is given, each label is tried in order and the first
        label with a match wins.

        Returns:
            The formatted document string, or ``None`` if no node matches.
        """
        # Plain int: numpy integer scalars are not guaranteed to be accepted
        # as Cypher parameters.
        params = {"index": int(index)}
        if doc_types is None:
            query = (
                "MATCH (n) WHERE n.id = $index "
                "RETURN n AS document, labels(n) AS document_type, n.id AS node_id"
            )
            result = neo4j_graph.query(query, params)
        else:
            result = []  # also covers an empty doc_types list
            for doc_type in doc_types:
                # Labels cannot be parameterized; backtick-escape them to
                # prevent Cypher injection through doc_types.
                label = "`" + str(doc_type).replace("`", "``") + "`"
                query = (
                    f"MATCH (n:{label}) WHERE n.id = $index "
                    "RETURN n AS document, labels(n) AS document_type, n.id AS node_id"
                )
                result = neo4j_graph.query(query, params)
                if result:
                    break
        if not result:
            return None
        record = result[0]
        labels = record['document_type']
        first_label = labels[0] if labels else ""
        return f"[{first_label}] {record['document']}"


# Example usage
if __name__ == "__main__":
    # Initialize the vector store with an embedding file. The query model must
    # be the one that produced the stored embeddings (same output dimension).
    vector_store = FAISSVectorStore(
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # replace with your model
        dimension=384,
        embedding_file="path/to/your/embeddings.pkl",  # or .npy file
    )

    # Initialize Neo4jGraph (replace with your actual Neo4j connection details)
    neo4j_graph = Neo4jGraph(
        url="bolt://localhost:7687",
        username="neo4j",
        password="password",
    )

    # Perform a similarity search with and without MMR.
    # similarity_search returns (results, faiss_ids).
    query = "How to start a long journey"
    results_simple, _ = vector_store.similarity_search(
        query, k=5, use_mmr=False, neo4j_graph=neo4j_graph
    )
    results_mmr, _ = vector_store.similarity_search(
        query, k=5, use_mmr=True, lambda_param=0.5, neo4j_graph=neo4j_graph
    )

    # Print the results
    for label, results in (("without MMR", results_simple), ("with MMR", results_mmr)):
        print(f"Top 5 similar texts for query: '{query}' ({label})")
        for i, result in enumerate(results, 1):
            print(f"{i}. Text: {result['document']}")
            print(f"   Score: {result['score']}")
            print()