"""Evaluate HooshvareLab/albert-fa-zwnj-base-v2 on masked-token prediction over
Farsi news summaries, scoring the top-k fill-mask candidates by fastText
cosine similarity to the masked word."""

import random

import numpy as np
import pandas as pd
import fasttext
import fasttext.util
import normalizer  # local module: exposes cleaning() for Persian text normalization
from sklearn.metrics.pairwise import cosine_similarity
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

random.seed(42)

tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/albert-fa-zwnj-base-v2")
# model = AutoModelForMaskedLM.from_pretrained("HooshvareLab/albert-fa-zwnj-base-v2")


# Load pre-trained Persian fastText word embeddings (cc.fa.300.bin, 300-dimensional vectors)
fasttext.util.download_model('fa', if_exists='ignore')  # Persian
embeddings = fasttext.load_model('cc.fa.300.bin')

# Example sentences with masked tokens
# masked_sentences = [
#     ("The capital of France is [MASK].", "Paris"),
#     ("The [MASK] is the largest mammal.", "whale"),
#     ("The fastest land animal is the [MASK].", "cheetah")
# ]

# df = pd.read_excel('law_excel.xlsx', sheet_name='Sheet1')
# dataset = Dataset.from_pandas(df)
# Load the Hamshahri split of the Farsi news dataset and sample 100 articles
dataset = load_dataset('community-datasets/farsi_news', split='hamshahri')
dataset = dataset.shuffle(seed=42).select(range(100))

def tokenize_dataset(examples):
    """Tokenize each summary and replace one random non-special token with [MASK]."""
    result = tokenizer(examples['summary'])

    result['masked_token'] = [''] * len(result['input_ids'])
    for i, input_ids in enumerate(result['input_ids']):
        # Pick a random position, excluding the [CLS] and [SEP] special tokens
        rand = random.randint(1, len(input_ids) - 2)
        # Remember the original token (the ground truth) before masking it
        result['masked_token'][i] = tokenizer.decode(input_ids[rand])
        result['input_ids'][i][rand] = tokenizer.mask_token_id

    return result

dataset = dataset.map(tokenize_dataset, batched=True)
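# After mapping, each example's 'input_ids' contains exactly one [MASK] token
# and 'masked_token' holds the original word that was masked out.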


# Initialize the fill-mask pipeline
fill_mask = pipeline("fill-mask", model="HooshvareLab/albert-fa-zwnj-base-v2")
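# Each call to fill_mask returns a list of candidate dicts of the form
# {'score': ..., 'token': ..., 'token_str': ..., 'sequence': ...},
# as documented for the Hugging Face fill-mask pipeline.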

# Define k for top-k predictions
k = 5
# Define similarity threshold
similarity_threshold = 0.5

# Prediction-level counters: each of the top-k candidates is scored once
TPP = 0  # candidates similar enough to the masked word
FPP = 0  # candidates below the similarity threshold

# Sentence-level counters: a sentence is "recalled" if any candidate matches
FNR = 0  # sentences with no matching candidate
TPR = 0  # sentences with at least one matching candidate

def get_embedding(word):
    # fastText builds vectors from subwords, so lookups rarely fail;
    # the guard is kept for safety.
    try:
        return embeddings[word]
    except KeyError:
        return None

# Evaluation loop: rebuild each masked sentence, query the model for its top-k
# candidates, and accept a candidate when the cosine similarity between its
# fastText vector and the masked word's vector reaches the threshold.
for data in dataset.iter(1):
    # Drop the [CLS]/[SEP] special tokens before decoding back to text
    sentence = tokenizer.decode(data['input_ids'][0][1:-1])
    sentence = normalizer.cleaning(sentence)
    ground_truth = data['masked_token'][0]

    # Get the top-k predictions for the [MASK] position
    predictions = fill_mask(sentence, top_k=k)
    predicted_tokens = [pred['token_str'] for pred in predictions]

    ground_truth_emb = get_embedding(ground_truth)

    if ground_truth_emb is None:
        continue  # Skip if the ground truth has no embedding

    flag = False
    for token in predicted_tokens:
        token_emb = get_embedding(token)
        if token_emb is not None:
            similarity = cosine_similarity([ground_truth_emb], [token_emb])[0][0]
            if similarity >= similarity_threshold:
                TPP += 1
                flag = True
            else:
                FPP += 1
    if flag:
        TPR += 1
    else:
        FNR += 1
            

# Compute precision (per prediction), recall (per sentence), and F1
precision = TPP / (TPP + FPP) if (TPP + FPP) > 0 else 0
recall = TPR / (TPR + FNR) if (TPR + FNR) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

result = {'model': "HooshvareLab/albert-fa-zwnj-base-v2",
          'evaluation_dataset': 'community-datasets/farsi_news',
          'Recall': recall,
          'Precision': precision,
          'F1': f1}

result = pd.DataFrame([result])

result.to_csv('result.csv', index=False)