import numpy as np
import fasttext
import fasttext.util
import pandas as pd
import random
import normalizer  # local text-cleaning module; used below via normalizer.cleaning()
from transformers import pipeline
from sklearn.metrics.pairwise import cosine_similarity
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM

random.seed(42)

tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/albert-fa-zwnj-base-v2")
# model = AutoModelForMaskedLM.from_pretrained("HooshvareLab/albert-fa-zwnj-base-v2")
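# The tokenizer comes from the same checkpoint as the fill-mask pipeline defined further
# down, so the special tokens and mask token id used during masking match the model.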


# Load pre-trained Persian word embeddings (fastText Common Crawl vectors)
fasttext.util.download_model('fa', if_exists='ignore')  # Persian
embeddings = fasttext.load_model('cc.fa.300.bin')
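# Note: cc.fa.300.bin holds 300-dimensional vectors and is a multi-gigabyte download on
# the first run; if_exists='ignore' reuses the local copy on later runs.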

# Example sentences with masked tokens
# masked_sentences = [
#     ("The capital of France is [MASK].", "Paris"),
#     ("The [MASK] is the largest mammal.", "whale"),
#     ("The fastest land animal is the [MASK].", "cheetah")
# ]

# df = pd.read_excel('law_excel.xlsx', sheet_name='Sheet1')
# dataset = Dataset.from_pandas(df)
dataset = load_dataset('community-datasets/farsi_news', split='hamshahri')
dataset = dataset.shuffle(seed=42).select(range(100))
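# Keep a small random sample (100 summaries) so the evaluation loop stays quick.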

def tokenize_dataset(examples):
    # Tokenize each summary and mask one randomly chosen non-special token per example.
    result = tokenizer(examples['summary'])

    masked_tokens = []
    for example in result['input_ids']:
        # Pick a random position, skipping the [CLS]/[SEP] special tokens at the ends.
        rand = random.randint(1, len(example) - 2)
        # Remember the original token so it can serve as the ground truth later.
        masked_tokens.append(tokenizer.decode(example[rand]))
        # Replace it with the tokenizer's mask token id instead of a hard-coded 4.
        example[rand] = tokenizer.mask_token_id

    result['masked_token'] = masked_tokens
    return result

dataset = dataset.map(tokenize_dataset, batched=True)
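# map(batched=True) passes lists of summaries to tokenize_dataset, so exactly one token
# is masked per summary and stored alongside its input_ids.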


# Initialize the fill-mask pipeline
fill_mask = pipeline("fill-mask", model="HooshvareLab/albert-fa-zwnj-base-v2")

# Define k for top-k predictions
k = 5
# Define similarity threshold
similarity_threshold = 0.5
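# Scoring scheme: a predicted token is treated as a hit when the cosine similarity between
# its fastText vector and the ground-truth token's vector reaches the threshold, so
# near-synonyms can count as correct even if they differ from the original token.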

# Per-prediction counters: a prediction counts as correct when its embedding is
# similar enough to the ground-truth token's embedding
TPP = 0
FPP = 0

# Per-example counters: an example counts as recalled when at least one of its
# top-k predictions passes the similarity threshold
FNR = 0
TPR = 0

def get_embedding(word):
    # fastText composes vectors from character n-grams, so even out-of-vocabulary words
    # normally receive an embedding; the KeyError guard is kept as a defensive fallback.
    try:
        return embeddings[word]
    except KeyError:
        return None

for data in dataset.iter(batch_size=1):
    # Decode the masked sequence, dropping the [CLS]/[SEP] tokens at either end;
    # normalizer.cleaning must leave the [MASK] token intact for the fill-mask pipeline.
    sentence = tokenizer.decode(data['input_ids'][0][1:-1])
    sentence = normalizer.cleaning(sentence)
    ground_truth = data['masked_token'][0]

    # Get the top-k predictions for the masked position
    predictions = fill_mask(sentence, top_k=k)
    predicted_tokens = [pred['token_str'] for pred in predictions]

    ground_truth_emb = get_embedding(ground_truth)
    
    if ground_truth_emb is None:
        continue  # Skip if ground truth is not in the embeddings

    flag = False
    for token in predicted_tokens:
        token_emb = get_embedding(token)
        if token_emb is not None:
            similarity = cosine_similarity([ground_truth_emb], [token_emb])[0][0]
            if similarity >= similarity_threshold:
                TPP += 1
                flag = True
            else:
                FPP += 1
    if flag:
        TPR += 1
    else:
        FNR += 1
            

# Compute precision over individual predictions and recall over examples
precision = TPP / (TPP + FPP) if (TPP + FPP) > 0 else 0
recall = TPR / (TPR + FNR) if (TPR + FNR) > 0 else 0

print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

result = {'model': "HooshvareLab/albert-fa-zwnj-base-v2",
          'evaluation_dataset': 'community-datasets/farsi_news',
          'Recall': recall,
          'Precision': precision,
          'F1': (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0}

result = pd.DataFrame([result])

result.to_csv('result.csv', index=False)