Spaces:
Running
on
L4
Running
on
L4
from src.deepeval.base_task import BaseTask | |
from collections import defaultdict | |
from src.deepeval.utils import accuracy, accuracy_standard_error | |
from typing import Any | |
import re | |
class NERTask(BaseTask): | |
def __init__(self, model_name): | |
super().__init__("metunlp/tr_ner", model_name=model_name) | |
def load_dataset_from_hf(self): | |
dataset = super().load_dataset_from_hf() | |
return dataset | |
def generate_response_oeqa_multi_token(self, msg,max_new_tokens: int = 128): | |
""" | |
Handles multiple-choice questions where answers might have multiple tokens. | |
""" | |
# Ensure tokenizer has proper special tokens set | |
if self.tokenizer.pad_token is None: | |
self.tokenizer.pad_token = self.tokenizer.eos_token | |
if self.model.config.pad_token_id is None: | |
self.model.config.pad_token_id = self.tokenizer.pad_token_id | |
chat = [ | |
{"role": "user", "content": "You are a question-answering chatbot."}, | |
{"role": "assistant", "content": "I am ready to answer your questions. Feel free to ask anything.\n"}, | |
{"role": "user", "content": f"{msg}"}, | |
] | |
formatted_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) | |
inputs = self.tokenizer(formatted_chat, return_tensors="pt", padding=True, truncation=True) | |
input_ids = inputs.input_ids.to(self.model.device) | |
attention_mask = inputs.attention_mask.to(self.model.device) | |
# Generate response with proper token limits | |
output = self.model.generate( | |
input_ids, | |
do_sample=True, | |
attention_mask=attention_mask, | |
eos_token_id=self.tokenizer.eos_token_id, | |
pad_token_id=self.tokenizer.pad_token_id, | |
temperature=0.4, | |
max_new_tokens=max_new_tokens, | |
) | |
generated_ids = output[0] # The generated sequence including the prompt | |
generated_tokens = generated_ids[len(input_ids[0]):] # Exclude the input_ids part | |
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) | |
return generated_text | |
def evaluate(self) -> dict[str, Any]: | |
responses = [] | |
difficulty_results = defaultdict(lambda: {'correct': 0, 'total': 0}) | |
total_count = 0 | |
true = 0 | |
for row in self.dataset: | |
total_count += 1 | |
# Get values from row | |
category = str(row["difficulty"]) | |
answer = row["final_answer"] | |
question = row["question"] | |
# Construct the prompt/message | |
instruction = ("Aşağıdaki Named Entity Recognition (NER) için etiketlenmesi gereken cümleler vardır. " | |
"Cümlelerdeki varlıkları belirleyin ve şu kategorilere ayırın: CARDINAL, DATE, EVENT, FAC, GPE, LANGUAGE, LAW, LOC, MONEY, NORP, ORDINAL, ORG, PER, PERCENT, PERSON, PRODUCT, QUANTITY, TIME, TITLE, WORK_OF_ART. " | |
"" | |
"Varlıklar, anlamlı bilgiler içeren terimlerdir ve aşağıdaki şekilde tanımlanır: " | |
"CARDINAL: Nicelik veya sıralama belirtmeyen sayısal ifadeler." | |
"DATE: Belirli bir tarih veya zaman ifadeleri." | |
"EVENT: Adlandırılmış olaylar veya durumlar." | |
"FAC: Binalar veya önemli yerler gibi tesisler." | |
"GPE: Ülke, şehir veya eyalet gibi coğrafi-politik varlıklar." | |
"LANGUAGE: Adlandırılmış diller." | |
"LAW: Yasal belgeler, düzenlemeler veya kanunlar." | |
"LOC: Coğrafi veya fiziksel konumlar (GPE dışındaki)." | |
"MONEY: Parasal değerler." | |
"NORP: Milletler, dini veya siyasi gruplar." | |
"ORDINAL: Sıralama veya dereceler." | |
"ORG: Organizasyonlar veya kurumlar." | |
"PER: Kişisel unvanlar veya sıfatlar." | |
"PERSON: Bireylerin isimleri." | |
"PRODUCT: Üretilen nesneler veya araçlar." | |
"QUANTITY: Ölçülebilir miktarlar ve birimler." | |
"TIME: Günün belirli saatleri." | |
"TITLE: Kişi unvanları." | |
"WORK_OF_ART: Sanat eserleri, kitaplar, müzik vb. Adlar, tarih ifadeleri, konumlar gibi belirgin bilgiler varlıktır." | |
"" | |
"Fiiller, sıfatlar, zarflar, soyut kavramlar gibi ifadeler varlık değildir. Çıktıyı aşağıdaki JSON formatında döndürün. " | |
"" | |
"Örnekler: " | |
"Girdi: " | |
"sentence: \"Üç yıl aradan sonra gerçekleştirilen ve Karadeniz, Ege ve Akdeniz’de düzenlenecek olan tatbikata ilişkin Yunanistan'ın Kathimerini gazetesi 'Türk-Yunan: Çetin donanma dengesinin gücü' başlığını kullandı.\"" | |
"Çıktı: " | |
"Üç yıl,DATE" | |
"Karadeniz,LOC" | |
"Ege,LOC" | |
"Akdeniz,LOC" | |
"Yunanistan,GPE" | |
"Kathimerini,ORG" | |
"Türk,NORP" | |
"" | |
"Girdi:" | |
"sentence: \"Evlendikten sonra oyunculuğu bırakan Makal, geçen yıl eşi ve oğluyla beraber İstanbul’dan Göcek’e taşınmıştı." | |
"Çıktı: " | |
"Makal,PERSON" | |
"İstanbul,GPE" | |
"Göcek,GPE" | |
"" | |
"Girdi:" | |
"sentence: \"Yeşil-kırmızılılardan 2016’da ayrılıp 3 sezonluk aradan sonra 2019’da geri dönen Sarıca, takımına 2021 yılında Şampiyonlar Ligi’nde, 2023’te de Süper Lig’de iki final oynattı." | |
"Çıktı:" | |
"2016’da,DATE" | |
"3,CARDINAL" | |
"2019’da,DATE" | |
"Sarıca,PERSON" | |
"2021,DATE" | |
"Şampiyonlar Ligi’nde,EVENT" | |
"2023’te,DATE" | |
"Süper Lig’de,EVENT" | |
"iki,CARDINAL" | |
"" | |
"Verilen cümlelerdeki her varlığı csv formatında yukarıdaki örneklere benzer şekilde belirleyin. Çıktıdaki her satırı aşağıdaki gibi oluşturun: " | |
"<Varlık metni>,<Varlık etiketi>"), | |
prompt = f"{instruction}\n\nSoru:\n{question}\n" | |
message = prompt | |
# Get/format answer of the model | |
model_answer = self.generate_response_oeqa_multi_token(message) | |
responses.append(model_answer) | |
model_answer_cleaned = model_answer | |
# Check if correct based on metric | |
if answer == model_answer_cleaned: | |
true += 1 | |
difficulty_results[category]['correct'] += 1 | |
difficulty_results[category]['total'] += 1 | |
# Print results categorized by difficulty | |
for category, stats in difficulty_results.items(): | |
correct = stats['correct'] | |
total = stats['total'] | |
calculatedAccuracy = correct / total if total > 0 else 0 | |
print(f"{category.capitalize()} Accuracy: {calculatedAccuracy:.2%} ({correct}/{total})") | |
print("Results:", responses) | |
print("Overall Accuracy:", true / total_count) | |
acc = accuracy(true, total_count) | |
acc_stderr = accuracy_standard_error(acc, total_count) | |
return {"acc": acc, "acc_stderr": acc_stderr} | |