import gradio as gr
from sentence_transformers import SentenceTransformer, util
import pandas as pd
from datasets import load_dataset
from annoy import AnnoyIndex
import os

try:
    # Load the dataset (Italian subset, test split)
    dataset = load_dataset("PhilipMay/stsb_multi_mt", name="it", split="test")
    df = pd.DataFrame(dataset)

    # Extract sentences (sentence1 and sentence2)
    sentences1 = df["sentence1"].tolist()
    sentences2 = df["sentence2"].tolist()

    # Sentence-transformers models to test
    model_names = [
        "nickprock/multi-sentence-BERTino",
        "nickprock/sentence-bert-base-italian-uncased",
        "nickprock/static-similarity-mmarco3m-mrl-BERTino-v1.5",
        "nickprock/Italian-ModernBERT-base-embed-mmarco-mnrl",
    ]

    models = {name: SentenceTransformer(name) for name in model_names}
    annoy_indexes1 = {}  # Store Annoy indexes for sentence1
    annoy_indexes2 = {}  # Store Annoy indexes for sentence2

    def build_annoy_index(model_name, sentences):
        """Encode ``sentences`` with the named model and index the vectors with Annoy.

        Item ids in the returned index are positions in ``sentences``, so a
        neighbour id can be used directly to index back into that list. The
        index uses Annoy's "angular" metric (cosine-based) and is built with
        10 trees.
        """
        vectors = models[model_name].encode(sentences)
        index = AnnoyIndex(vectors.shape[1], "angular")
        for item_id in range(len(vectors)):
            index.add_item(item_id, vectors[item_id])
        index.build(10)  # more trees -> better recall, larger index
        return index

    # Pre-compute one Annoy index per (model, sentence column) so each query is only a lookup.
    for model_name in model_names:
        for index_store, column_sentences in (
            (annoy_indexes1, sentences1),
            (annoy_indexes2, sentences2),
        ):
            index_store[model_name] = build_annoy_index(model_name, column_sentences)

    def find_similar_sentence_annoy(sentence, model_name, sentence_list, annoy_index):
        """Return the entry of ``sentence_list`` closest to ``sentence`` for one model.

        Despite its name, ``annoy_index`` is a mapping from model name to an
        Annoy index (callers pass ``annoy_indexes1`` / ``annoy_indexes2``). The
        index for ``model_name`` is queried for its single nearest neighbour,
        and that item id is used as a position in ``sentence_list``.
        """
        query_vector = models[model_name].encode(sentence)
        index_for_model = annoy_index[model_name]
        neighbour_ids = index_for_model.get_nns_by_vector(query_vector, 1)
        return sentence_list[neighbour_ids[0]]

    def calculate_cosine_similarity(sentence1, sentence2, model):
        """Return the cosine similarity of two sentences under ``model`` as a Python scalar."""
        # Encode each sentence separately (in order), then compare the two embeddings.
        first_vec, second_vec = (model.encode(text) for text in (sentence1, sentence2))
        return util.cos_sim(first_vec, second_vec).item()

    def compare_models_annoy(sentence, model1_name, model2_name, model3_name, model4_name):
        """Query each selected model's Annoy indexes and score the two matches.

        For every distinct, non-empty model name among the four arguments, this
        finds the sentence in ``sentences1`` and the sentence in ``sentences2``
        nearest to ``sentence``. It then computes the cosine similarity between
        those two matches, using that same model.

        Returns:
            ``(sentence1_results, sentence2_results, similarities)``: three dicts
            keyed by model name, in the order the models were selected.
        """
        # The same model may be picked in several dropdowns, and a dropdown may
        # be cleared (presumably yielding None/"" — confirm with Gradio). Keep
        # each real selection once, preserving dropdown order.
        selected_models = [
            name
            for name in dict.fromkeys((model1_name, model2_name, model3_name, model4_name))
            if name
        ]

        sentence1_results = {}
        sentence2_results = {}
        similarities = {}
        for model_name in selected_models:
            sentence1_results[model_name] = find_similar_sentence_annoy(
                sentence, model_name, sentences1, annoy_indexes1
            )
            sentence2_results[model_name] = find_similar_sentence_annoy(
                sentence, model_name, sentences2, annoy_indexes2
            )
            # Score only the models that were actually queried. Looping over the
            # global ``model_names`` raised KeyError for any model not selected.
            similarities[model_name] = calculate_cosine_similarity(
                sentence1_results[model_name], sentence2_results[model_name], models[model_name]
            )

        return sentence1_results, sentence2_results, similarities

    def format_results(sentence1_results, sentence2_results, similarities):
        """Render the per-model comparison as Markdown for the Gradio output.

        Args:
            sentence1_results: mapping model name -> nearest sentence from ``sentence1``.
            sentence2_results: mapping model name -> nearest sentence from ``sentence2``.
            similarities: mapping model name -> cosine similarity between the two
                nearest sentences.

        Returns:
            A Markdown string with one section per model in ``similarities``, in
            that mapping's order. Returns an empty string when it is empty.
        """
        # Iterate over the models actually present in the results rather than the
        # global model list, so results for a subset of models don't raise KeyError.
        sections = []
        for model_name, similarity in similarities.items():
            # List items ("- ") render as separate lines in Markdown. Plain single
            # newlines would collapse into one paragraph.
            sections.append(
                f"**{model_name}**\n\n"
                f"- Most Similar Sentence from sentence1: {sentence1_results[model_name]}\n"
                f"- Most Similar Sentence from sentence2: {sentence2_results[model_name]}\n"
                f"- Cosine Similarity: {similarity:.4f}\n"
            )
        return "\n".join(sections)

    def gradio_interface(sentence, model1_name, model2_name, model3_name, model4_name):
        """Gradio callback: compare the chosen models on ``sentence`` and return Markdown."""
        comparison = compare_models_annoy(
            sentence, model1_name, model2_name, model3_name, model4_name
        )
        # compare_models_annoy returns (sentence1_results, sentence2_results, similarities).
        return format_results(*comparison)

    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter your sentence here..."),
            gr.Dropdown(model_names, value=model_names[0], label="Model 1"),
            gr.Dropdown(model_names, value=model_names[1], label="Model 2"),
            gr.Dropdown(model_names, value=model_names[2], label="Model 3"),
            gr.Dropdown(model_names, value=model_names[3], label="Model 4"),
        ],
        outputs=gr.Markdown(),
        title="Sentence Transformer Model Comparison (Annoy)",
        description=(
            "Inserisce una frase e confronta le frasi più simili generate da diversi modelli "
            "sentence-transformer (utilizzando Annoy per una ricerca più veloce) sia dalla frase1 "
            "che dalla frase2. Calcola anche la similarità del coseno tra le frasi. "
            "Utilizza sentence-transformers per l'italiano e lo split test del dataset stsb_multi_mt."
        ),
    )

    iface.launch()

except Exception as e:
    print(f"Error loading dataset: {e}")
    iface = gr.Interface(
        fn=lambda: "Dataset loading failed. Check console for details.",
        inputs=[],
        outputs=gr.Textbox(),
        title="Dataset Loading Error",
        description="There was an error loading the dataset.",
    )
    iface.launch()