from src.deepeval.base_task import BaseTask
from collections import defaultdict
from src.deepeval.utils import accuracy, accuracy_standard_error
from typing import Any


class NLITask(BaseTask):
    def __init__(self, model_name):
        super().__init__("metunlp/nli_tr", model_name=model_name)

    def load_dataset_from_hf(self):
        # No NLI-specific preprocessing yet; delegate to the BaseTask loader.
        dataset = super().load_dataset_from_hf()
        return dataset

    def evaluate(self) -> dict[str, Any]:
        responses = []
        difficulty_results = defaultdict(lambda: {'correct': 0, 'total': 0})
        total_count = 0
        correct_count = 0

        for row in self.dataset:
            total_count += 1

            # Get values from row
            text = row["text"]
            premise = row["premise"]
            hypothesis = row["hypothesis"]
            label = row["label"].lower().replace(' ', '')
            choices = ["entailment", "contradiction", "neutral"]
            formatted_choices = "\n".join([f"{chr(65+i)}: {choice}" for i, choice in enumerate(choices)])
            category = row["difficulty"]
            correct_answer_letter = "A" if label == "entailment" else \
                                    "B" if label == "contradiction" else \
                                    "C" if label == "neutral" else None


            # Prints for debugging
            # print(f"Choices: {choices}")
            # print("Type of choices:", type(choices))
            # print("Label:", label)

            # Construct the prompt/message (the question asks, in Turkish, whether the
            # relationship between the two sentences is entailment, neutral, or contradiction)
            instruction = ""
            question = ("Yukarıdaki cümleler arasındaki ilişki “entailment” (bir cümle diğerini ima eder), "
                        "“neutral” (cümleler birbirini ima etmez ve çelişmez) veya “contradiction” "
                        "(cümleler birbirleriyle çelişir) olarak karakterize edilebilir. "
                        "Bu ilişkilerden hangisi olduğunu söyleyin.")
            context = f"Bağlam:\n{text}\n"  # can be added to the prompt if needed
            prompt = (f"Cümle1:\n{premise}\nCümle2:\n{hypothesis}\nSoru:\n{question}\n"
                      f"Seçenekler:\n{formatted_choices}\n{instruction}\n")
            message = prompt

            # Get/format answer of the model
            model_answer = self.generate_response_mcqa_multi_token(message, choices=choices, max_new_tokens=2)
            responses.append(model_answer)
            model_answer_cleaned = model_answer.strip().replace('\n', '').replace(' ', '').upper()

            # Print answers
            # print(f"Correct Answer: {correct_answer_letter}")
            # print(f"Model Answer: {model_answer}")
            # print(f"Model Answer Cleaned: {model_answer_cleaned}")

            # Check if correct based on metric
            if correct_answer_letter == model_answer_cleaned:
                correct_count += 1
                difficulty_results[category]['correct'] += 1

            difficulty_results[category]['total'] += 1

        # Print results categorized by difficulty
        for category, stats in difficulty_results.items():
            calculated_accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
            print(f"{category.capitalize()} Accuracy: {calculated_accuracy:.2%} ({stats['correct']}/{stats['total']})")

        print("Results:", responses)
        print("Overall Accuracy:", correct_count / total_count if total_count > 0 else 0)
        acc = accuracy(correct_count, total_count)
        acc_stderr = accuracy_standard_error(acc, total_count)
        return {"acc": acc, "acc_stderr": acc_stderr}