File size: 3,483 Bytes
847b372
17e229f
847b372
 
 
17e229f
7b3d3a5
847b372
 
 
 
 
dbf76bc
847b372
 
 
 
17e229f
 
 
 
847b372
17e229f
 
 
847b372
 
 
17e229f
 
d866f01
 
17e229f
 
8a3d32e
 
 
847b372
17e229f
 
 
 
 
 
 
 
847b372
 
 
 
 
 
 
17e229f
 
d866f01
17e229f
7b3d3a5
17e229f
1657c25
17e229f
7b3d3a5
17e229f
 
8a3d32e
 
 
847b372
17e229f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
847b372
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from src.deepeval.base_task import BaseTask
from collections import defaultdict
from src.deepeval.utils import accuracy, accuracy_standard_error
from typing import Any


class CommonsenseReasoningTask(BaseTask):
    def __init__(self, model_name):
        super().__init__("metunlp/commonsense", model_name=model_name)

    def load_dataset_from_hf(self):
        dataset = super().load_dataset_from_hf()
        return dataset


    def evaluate(self) -> dict[str, Any]:
        responses = []
        difficulty_results = defaultdict(lambda: {'correct': 0, 'total': 0})
        total_count = 0
        true = 0

        for row in self.dataset:
            total_count += 1

            # Get values from row
            label = row["label"]
            choices=[row["choice1"], row["choice2"]]
            formatted_choices = "\n".join([f"{chr(65+i)}: {choice}" for i, choice in enumerate(choices)])
            category = row["difficulty"]
            answer = row["answer"]
            text = row["text"]
            context = row["context"]

            # Prints for debugging
            # print(f"Choices: {choices}")
            # print("Type of choices:", type(choices))
            # print("Type of answer:", type(answer))

            # Get answer index (starting from 0)
            if type(answer) == int:
                answer_index = answer - 1 # 1 or 2
            else:
                answer_index = int(answer) - 1
            correct_answer_letter = chr(65 + answer_index)

            # Get question based on label
            if label == "effect":
                question = "Seçeneklerden hangisi verilen önermenin bir sonucu veya etkisi olabilir?"
            elif label == "cause":
                question = "Seçeneklerden hangisi verilen önermenin bir neden veya sebebi olabilir?"
            else:
                question = "Seçeneklerden hangisi uygun?"  # Alternatif

            # Construct the prompt/message
            instruction = ""
            prompt = f"Bağlam:\n{text}\nÖnerme:\n{context}\nSoru:{question}\nSeçenekler:\n{formatted_choices}\n{instruction}\n"
            message = prompt

            # Get/format answer of the model
            model_answer = self.generate_response_mcqa_multi_token(message, choices=choices, max_new_tokens=2)
            responses.append(model_answer)
            model_answer_cleaned = model_answer.strip().replace('\n', '').replace(' ', '').upper()

            # Print answers
            # print(f"Correct Answer: {correct_answer_letter}")
            # print(f"Model Answer: {model_answer}")
            # print(f"Model Answer Cleaned: {model_answer_cleaned}")

            # Check if correct based on metric
            if correct_answer_letter == model_answer_cleaned:
                true += 1
                difficulty_results[category]['correct'] += 1

            difficulty_results[category]['total'] += 1

        # Print results categorized by difficulty
        for category, stats in difficulty_results.items():
            calculatedAccuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
            print(f"{category.capitalize()} Accuracy: {calculatedAccuracy:.2%} ({stats['correct']}/{stats['total']})")

        print("Results:", responses)
        print("Overall Accuracy:", true / total_count)
        acc = accuracy(true, total_count)
        acc_stderr = accuracy_standard_error(acc, total_count)
        return {"acc": acc, "acc_stderr": acc_stderr}