Spaces:
Paused
Paused
File size: 7,707 Bytes
ca54ffd f6890a5 ca54ffd 9828c0e ca54ffd 9828c0e ca54ffd 9828c0e ca54ffd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from src.deepeval.base_task import BaseTask
from collections import defaultdict
from src.deepeval.utils import accuracy, accuracy_standard_error
from typing import Any
import re
class NERTask(BaseTask):
def __init__(self, model_name):
super().__init__("metunlp/tr_ner", model_name=model_name)
def load_dataset_from_hf(self):
dataset = super().load_dataset_from_hf()
return dataset
def generate_response_oeqa_multi_token(self, msg,max_new_tokens: int = 128):
"""
Handles multiple-choice questions where answers might have multiple tokens.
"""
# Ensure tokenizer has proper special tokens set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.model.config.pad_token_id is None:
self.model.config.pad_token_id = self.tokenizer.pad_token_id
chat = [
{"role": "user", "content": "You are a question-answering chatbot."},
{"role": "assistant", "content": "I am ready to answer your questions. Feel free to ask anything.\n"},
{"role": "user", "content": f"{msg}"},
]
formatted_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = self.tokenizer(formatted_chat, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs.input_ids.to(self.model.device)
attention_mask = inputs.attention_mask.to(self.model.device)
# Generate response with proper token limits
output = self.model.generate(
input_ids,
do_sample=True,
attention_mask=attention_mask,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
temperature=0.4,
max_new_tokens=max_new_tokens,
)
generated_ids = output[0] # The generated sequence including the prompt
generated_tokens = generated_ids[len(input_ids[0]):] # Exclude the input_ids part
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return generated_text
def evaluate(self) -> dict[str, Any]:
responses = []
difficulty_results = defaultdict(lambda: {'correct': 0, 'total': 0})
total_count = 0
true = 0
for row in self.dataset:
total_count += 1
# Get values from row
category = str(row["difficulty"])
answer = row["final_answer"]
question = row["question"]
# Construct the prompt/message
instruction = ("Aşağıdaki Named Entity Recognition (NER) için etiketlenmesi gereken cümleler vardır. "
"Cümlelerdeki varlıkları belirleyin ve şu kategorilere ayırın: CARDINAL, DATE, EVENT, FAC, GPE, LANGUAGE, LAW, LOC, MONEY, NORP, ORDINAL, ORG, PER, PERCENT, PERSON, PRODUCT, QUANTITY, TIME, TITLE, WORK_OF_ART. "
""
"Varlıklar, anlamlı bilgiler içeren terimlerdir ve aşağıdaki şekilde tanımlanır: "
"CARDINAL: Nicelik veya sıralama belirtmeyen sayısal ifadeler."
"DATE: Belirli bir tarih veya zaman ifadeleri."
"EVENT: Adlandırılmış olaylar veya durumlar."
"FAC: Binalar veya önemli yerler gibi tesisler."
"GPE: Ülke, şehir veya eyalet gibi coğrafi-politik varlıklar."
"LANGUAGE: Adlandırılmış diller."
"LAW: Yasal belgeler, düzenlemeler veya kanunlar."
"LOC: Coğrafi veya fiziksel konumlar (GPE dışındaki)."
"MONEY: Parasal değerler."
"NORP: Milletler, dini veya siyasi gruplar."
"ORDINAL: Sıralama veya dereceler."
"ORG: Organizasyonlar veya kurumlar."
"PER: Kişisel unvanlar veya sıfatlar."
"PERSON: Bireylerin isimleri."
"PRODUCT: Üretilen nesneler veya araçlar."
"QUANTITY: Ölçülebilir miktarlar ve birimler."
"TIME: Günün belirli saatleri."
"TITLE: Kişi unvanları."
"WORK_OF_ART: Sanat eserleri, kitaplar, müzik vb. Adlar, tarih ifadeleri, konumlar gibi belirgin bilgiler varlıktır."
""
"Fiiller, sıfatlar, zarflar, soyut kavramlar gibi ifadeler varlık değildir. Çıktıyı aşağıdaki JSON formatında döndürün. "
""
"Örnekler: "
"Girdi: "
"sentence: \"Üç yıl aradan sonra gerçekleştirilen ve Karadeniz, Ege ve Akdeniz’de düzenlenecek olan tatbikata ilişkin Yunanistan'ın Kathimerini gazetesi 'Türk-Yunan: Çetin donanma dengesinin gücü' başlığını kullandı.\""
"Çıktı: "
"Üç yıl,DATE"
"Karadeniz,LOC"
"Ege,LOC"
"Akdeniz,LOC"
"Yunanistan,GPE"
"Kathimerini,ORG"
"Türk,NORP"
""
"Girdi:"
"sentence: \"Evlendikten sonra oyunculuğu bırakan Makal, geçen yıl eşi ve oğluyla beraber İstanbul’dan Göcek’e taşınmıştı."
"Çıktı: "
"Makal,PERSON"
"İstanbul,GPE"
"Göcek,GPE"
""
"Girdi:"
"sentence: \"Yeşil-kırmızılılardan 2016’da ayrılıp 3 sezonluk aradan sonra 2019’da geri dönen Sarıca, takımına 2021 yılında Şampiyonlar Ligi’nde, 2023’te de Süper Lig’de iki final oynattı."
"Çıktı:"
"2016’da,DATE"
"3,CARDINAL"
"2019’da,DATE"
"Sarıca,PERSON"
"2021,DATE"
"Şampiyonlar Ligi’nde,EVENT"
"2023’te,DATE"
"Süper Lig’de,EVENT"
"iki,CARDINAL"
""
"Verilen cümlelerdeki her varlığı csv formatında yukarıdaki örneklere benzer şekilde belirleyin. Çıktıdaki her satırı aşağıdaki gibi oluşturun: "
"<Varlık metni>,<Varlık etiketi>"),
prompt = f"{instruction}\n\nSoru:\n{question}\n"
message = prompt
# Get/format answer of the model
model_answer = self.generate_response_oeqa_multi_token(message)
responses.append(model_answer)
model_answer_cleaned = model_answer
# Check if correct based on metric
if answer == model_answer_cleaned:
true += 1
difficulty_results[category]['correct'] += 1
difficulty_results[category]['total'] += 1
# Print results categorized by difficulty
for category, stats in difficulty_results.items():
correct = stats['correct']
total = stats['total']
calculatedAccuracy = correct / total if total > 0 else 0
print(f"{category.capitalize()} Accuracy: {calculatedAccuracy:.2%} ({correct}/{total})")
print("Results:", responses)
print("Overall Accuracy:", true / total_count)
acc = accuracy(true, total_count)
acc_stderr = accuracy_standard_error(acc, total_count)
return {"acc": acc, "acc_stderr": acc_stderr}
|