# Ahmet Kaan Sever
# Removed unnecessary debug prints and timestamps now return seconds.
# 8a3d32e
from src.deepeval.base_task import BaseTask
from collections import defaultdict
from src.deepeval.utils import accuracy, accuracy_standard_error
from typing import Any
class NLITask(BaseTask):
    """Three-way natural language inference task on ``metunlp/nli_tr``.

    For every dataset row the model is shown a premise and a hypothesis and
    must pick one of entailment / contradiction / neutral (lettered A / B / C).
    Accuracy is printed per difficulty tag and returned for the whole dataset.
    """

    def __init__(self, model_name) -> None:
        """Bind the task to the ``metunlp/nli_tr`` dataset.

        Args:
            model_name: Forwarded unchanged to ``BaseTask.__init__`` as the
                ``model_name`` keyword argument.
        """
        super().__init__("metunlp/nli_tr", model_name=model_name)

    def load_dataset_from_hf(self):
        """Return the dataset exactly as the base class loads it."""
        return super().load_dataset_from_hf()

    def evaluate(self) -> dict[str, Any]:
        """Query the model on every row and score its multiple-choice answers.

        Each row must provide the keys ``"premise"``, ``"hypothesis"``,
        ``"label"`` and ``"difficulty"``. A row whose normalized label is not
        one of the three known classes can never match a model answer and is
        therefore counted as incorrect.

        Side effects:
            Prints per-difficulty accuracy, the raw model responses and the
            overall accuracy to stdout.

        Returns:
            ``{"acc": ..., "acc_stderr": ...}`` computed by ``accuracy`` and
            ``accuracy_standard_error``. For an empty dataset both values are
            ``0.0`` and those helpers are not called.
        """
        # Everything below is identical for every row, so build it once.
        choices = ["entailment", "contradiction", "neutral"]
        formatted_choices = "\n".join(
            f"{chr(65 + i)}: {choice}" for i, choice in enumerate(choices)
        )
        # Gold letter per label, derived from the same order as the prompt
        # options so the two cannot drift apart.
        letter_by_label = {
            choice: chr(65 + i) for i, choice in enumerate(choices)
        }
        # Empty placeholder; kept so the prompt's trailing blank line is
        # unchanged.
        instruction = ""
        # NOTE(review): the closing quotes after "neutral" and "contradiction"
        # are missing in this text. It is left byte-identical on purpose so
        # that scores stay comparable with earlier runs.
        question = (
            "Yukarıdaki cümleler arasındaki ilişki “entailment” (bir cümle diğerini ima eder), "
            "“neutral (cümleler birbirini ima etmez ve çelişmez) veya "
            "“contradiction (cümleler birbirleriyle çelişir) olarak karakterize edilebilir. "
            "Bu ilişkilerden hangisi olduğunu söyleyin."
        )

        responses = []
        difficulty_results = defaultdict(lambda: {"correct": 0, "total": 0})
        total_count = 0
        correct_count = 0

        # presumably self.dataset is populated by BaseTask; verify there.
        for row in self.dataset:
            total_count += 1

            # NOTE: rows also provide a "text" context field; it is not
            # included in the prompt.
            premise = row["premise"]
            hypothesis = row["hypothesis"]
            label = row["label"].lower().replace(" ", "")
            category = row["difficulty"]

            # None for an unrecognized label, which never equals an answer.
            correct_answer_letter = letter_by_label.get(label)

            prompt = (
                f"Cümle1:\n{premise}\nCümle2:{hypothesis}\nSoru:\n{question}\n"
                f"Seçenekler:\n{formatted_choices}\n{instruction}\n"
            )

            # Hand the helper its own list so the shared options cannot be
            # altered between rows.
            model_answer = self.generate_response_mcqa_multi_token(
                prompt, choices=list(choices), max_new_tokens=2
            )
            responses.append(model_answer)
            model_answer_cleaned = (
                model_answer.strip().replace("\n", "").replace(" ", "").upper()
            )

            stats = difficulty_results[category]
            stats["total"] += 1
            if correct_answer_letter == model_answer_cleaned:
                correct_count += 1
                stats["correct"] += 1

        # Report results grouped by difficulty.
        for category, stats in difficulty_results.items():
            category_accuracy = (
                stats["correct"] / stats["total"] if stats["total"] > 0 else 0
            )
            # str() keeps this working if the difficulty tag is not a string.
            print(
                f"{str(category).capitalize()} Accuracy: {category_accuracy:.2%} "
                f"({stats['correct']}/{stats['total']})"
            )

        print("Results:", responses)

        if total_count == 0:
            # Nothing was evaluated: report zeros instead of dividing by zero.
            print("Overall Accuracy:", 0.0)
            return {"acc": 0.0, "acc_stderr": 0.0}

        print("Overall Accuracy:", correct_count / total_count)
        acc = accuracy(correct_count, total_count)
        acc_stderr = accuracy_standard_error(acc, total_count)
        return {"acc": acc, "acc_stderr": acc_stderr}