import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Optional, Tuple
from langchain_community.graphs import Neo4jGraph
import pickle

class FAISSVectorStore:
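    """FAISS-backed vector store that keeps only the embeddings in the index and
    resolves the matching documents from Neo4j at query time.

    Assumes the i-th embedding added to the index belongs to the Neo4j node whose
    `id` property equals i (see `_get_text_by_index`). A `model_name` must be
    supplied if `similarity_search` is going to be called.
    """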
    def __init__(self, model_name: Optional[str] = None, dimension: int = 384, embedding_file: Optional[str] = None, trust_remote_code: bool = False):
        self.model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code) if model_name is not None else None
        self.index = faiss.IndexFlatIP(dimension)
        self.dimension = dimension
        if embedding_file:
            self.load_embeddings(embedding_file)

    def load_embeddings(self, file_path: str):
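        """Load precomputed embeddings from a pickle (.pkl) or NumPy (.npy) file and
        add them to the index; the row order determines the FAISS ids later used to
        look up documents in Neo4j.
        """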
        if file_path.endswith('.pkl'):
            with open(file_path, 'rb') as f:
                embeddings = pickle.load(f)
        elif file_path.endswith('.npy'):
            embeddings = np.load(file_path)
        else:
            raise ValueError("Unsupported file format. Use .pkl or .npy")
        
        self.add_embeddings(embeddings)

    def add_embeddings(self, embeddings: np.ndarray):
        # FAISS expects a C-contiguous float32 array; normalizing makes the inner
        # product of IndexFlatIP equivalent to cosine similarity.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

    def similarity_search(self, query: str, k: int = 5, use_mmr: bool = False, lambda_param: float = 0.5, doc_types: Optional[List[str]] = None, neo4j_graph: Optional[Neo4jGraph] = None) -> Tuple[List[dict], List[int]]:
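        """Embed `query` and return `(results, indices)`: `results` is a list of
        `{'document', 'score'}` dicts resolved through `neo4j_graph`, and `indices`
        are the corresponding FAISS ids. With `use_mmr=True`, candidates are re-ranked
        by maximal marginal relevance; `lambda_param` trades relevance (1.0) against
        diversity (0.0).
        """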
        query_vector = self.model.encode([query])
        faiss.normalize_L2(query_vector)
        
        if use_mmr:
            return self._mmr_search(query_vector, k, lambda_param, neo4j_graph, doc_types)
        else:
            return self._simple_search(query_vector, k, neo4j_graph, doc_types)

    def _simple_search(self, query_vector: np.ndarray, k: int, neo4j_graph: Neo4jGraph, doc_types: Optional[List[str]] = None) -> Tuple[List[dict], List[int]]:
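        """Plain top-k inner-product search, skipping hits that cannot be resolved in Neo4j."""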
        distances, indices = self.index.search(query_vector, k)
        
        results = []
        results_idx = []
        for i, idx in enumerate(indices[0]):
            if idx == -1:  # FAISS pads with -1 when fewer than k vectors are indexed
                continue
            document = self._get_text_by_index(neo4j_graph, int(idx), doc_types)
            if document is not None:
                results.append({
                    'document': document,
                    'score': float(distances[0][i])
                })
                results_idx.append(int(idx))
        
        return results, results_idx

    def _mmr_search(self, query_vector: np.ndarray, k: int, lambda_param: float, neo4j_graph: Neo4jGraph, doc_types: Optional[List[str]] = None) -> Tuple[List[dict], List[int]]:
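        """Maximal marginal relevance: over-fetch up to 2*k candidates, then greedily
        select the candidate maximising
        lambda_param * relevance - (1 - lambda_param) * max cosine similarity to the
        already selected set, keeping results relevant while avoiding near-duplicates.
        """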
        initial_k = min(k * 2, self.index.ntotal)
        distances, indices = self.index.search(query_vector, initial_k)
        
        # Reconstruct embeddings for the initial results
        initial_embeddings = self._reconstruct_embeddings(indices[0])
        
        selected_indices = []
        unselected_indices = list(range(len(indices[0])))
        
        for _ in range(min(k, len(indices[0]))):
            mmr_scores = []
            for i in unselected_indices:
                if not selected_indices:
                    mmr_scores.append((i, distances[0][i]))
                else:
                    embedding_i = initial_embeddings[i]
                    redundancy = max(self._cosine_similarity(embedding_i, initial_embeddings[j]) for j in selected_indices)
                    mmr_scores.append((i, lambda_param * distances[0][i] - (1 - lambda_param) * redundancy))
            
            selected_idx = max(mmr_scores, key=lambda x: x[1])[0]
            selected_indices.append(selected_idx)
            unselected_indices.remove(selected_idx)
        
        results = []
        results_idx = []
        for idx in selected_indices:
            faiss_idx = int(indices[0][idx])  # map back from candidate position to FAISS id
            document = self._get_text_by_index(neo4j_graph, faiss_idx, doc_types)
            if document is not None:
                results.append({
                    'document': document,
                    'score': float(distances[0][idx])
                })
                results_idx.append(faiss_idx)
        
        return results, results_idx

    def _reconstruct_embeddings(self, indices: np.ndarray) -> np.ndarray:
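        # reconstruct_batch requires a reasonably recent FAISS release; on older builds a
        # drop-in alternative is np.vstack([self.index.reconstruct(int(i)) for i in indices]).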
        return self.index.reconstruct_batch(indices)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    
    def _get_text_by_index(self, neo4j_graph: Neo4jGraph, index: int, doc_types: Optional[List[str]]) -> Optional[str]:
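        """Fetch the node whose `id` property equals the FAISS index, optionally trying
        the given labels in order, and return it as '[Label] {properties}'; returns
        None when no matching node is found.
        """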
        result = None
        if doc_types is None:
            query = """
            MATCH (n)
            WHERE n.id = $index
            RETURN n AS document, labels(n) AS document_type, n.id AS node_id
            """
            result = neo4j_graph.query(query, {"index": index})
        else:
            # Labels cannot be bound as Cypher parameters, so they are interpolated directly;
            # doc_types should therefore come from trusted code rather than user input.
            for doc_type in doc_types:
                query = f"""
                MATCH (n:{doc_type})
                WHERE n.id = $index
                RETURN n AS document, labels(n) AS document_type, n.id AS node_id
                """
                result = neo4j_graph.query(query, {"index": index})
                if result:
                    break

        if result:
            return f"[{result[0]['document_type'][0]}] {result[0]['document']}"
        return None
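

# Example helper (hypothetical, not used by the class above): a minimal sketch of how an
# embeddings file could be produced so that row i of the saved matrix lines up with the
# Neo4j node whose `id` property is i. Model name, texts, and output path are placeholders.
def build_embeddings(texts: List[str], output_path: str = "embeddings.npy",
                     model_name: str = "all-MiniLM-L6-v2") -> np.ndarray:
    model = SentenceTransformer(model_name)                  # 384-dim model, matching the default index
    embeddings = model.encode(texts, convert_to_numpy=True)  # shape: (len(texts), 384)
    np.save(output_path, embeddings.astype(np.float32))      # FAISS works with float32
    return embeddings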


# Example usage
if __name__ == "__main__":
    # Initialize the vector store with an embedding model (needed for queries) and an embedding file
    vector_store = FAISSVectorStore(model_name="all-MiniLM-L6-v2", dimension=384,
                                    embedding_file="path/to/your/embeddings.pkl")  # or .npy file

    # Initialize Neo4jGraph (replace with your actual Neo4j connection details)
    neo4j_graph = Neo4jGraph(
        url="bolt://localhost:7687",
        username="neo4j",
        password="password"
    )

    # Perform a similarity search with and without MMR
    query = "How to start a long journey"
    results_simple, _ = vector_store.similarity_search(query, k=5, use_mmr=False, neo4j_graph=neo4j_graph)
    results_mmr, _ = vector_store.similarity_search(query, k=5, use_mmr=True, lambda_param=0.5, neo4j_graph=neo4j_graph)

    # Print the results
    print(f"Top 5 similar texts for query: '{query}' (without MMR)")
    for i, result in enumerate(results_simple, 1):
        print(f"{i}. Document: {result['document']}")
        print(f"   Score: {result['score']}")
        print()

    print(f"Top 5 similar texts for query: '{query}' (with MMR)")
    for i, result in enumerate(results_mmr, 1):
        print(f"{i}. Document: {result['document']}")
        print(f"   Score: {result['score']}")
        print()