from src.deepeval.base_task import BaseTask
from collections import defaultdict
from src.deepeval.utils import accuracy, accuracy_standard_error
from typing import Any
import ast


class BiasTask(BaseTask):
    def __init__(self, model_name):
        super().__init__("metunlp/sosyoloji_bias", model_name=model_name)

    def load_dataset_from_hf(self):
        # No task-specific preprocessing is needed; reuse the base loader.
        return super().load_dataset_from_hf()

    def evaluate(self) -> dict[str, Any]:
        responses = []
        difficulty_results = defaultdict(lambda: {'correct': 0, 'total': 0})
        total_count = 0
        true = 0

        for row in self.dataset:
            # Each row yields two questions: one over the ambiguous context and one over the disambiguated context.
            total_count += 2

            # Get values from row

            ## common
            choices = ast.literal_eval(row["choices"])  # Convert string to list
            formatted_choices = "\n".join([f"{chr(65 + i)}: {choice}" for i, choice in enumerate(choices)])
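            # formatted_choices now reads "A: <choice 1>\nB: <choice 2>\n..."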

            ## ambiguous context
            ambiguous_context = row["ambiguous_context"]
            ambiguous_question = row["question_ambiguous"]
            ambiguous_answer = row["answer_ambiguous"]
            ambiguous_correct_answer_letter = chr(64 + ambiguous_answer)  # answers are 1-indexed, so 1 -> 'A', 2 -> 'B', ...
            ambiguous_prompt = f"Bağlam: {ambiguous_context}\nSoru: {ambiguous_question}\nSeçenekler:{formatted_choices}"

            ## disambiguated context
            disambiguated_context = row["disambiguated_context"]
            disambiguated_question = row["question_disambiguated"]
            disambiguated_answer = row["answer_disambiguated"]
            disambiguated_correct_answer_letter = chr(64 + disambiguated_answer)  # answers are 1-indexed, so 1 -> 'A', 2 -> 'B', ...
            disambiguated_prompt = f"Bağlam: {disambiguated_context}\nSoru: {disambiguated_question}\nSeçenekler:{formatted_choices}"

            # STAGE 1: ambiguous context
            instruction = ""  # placeholder; no extra instruction is appended for this task
            prompt = f"Stage1:\n{ambiguous_prompt}\n{instruction}\n"
            message = prompt
            ambiguous_model_answer = self.generate_response_mcqa_multi_token(message, choices=choices, max_new_tokens=2)
            # Normalize the model output to a bare option letter (e.g. "A").
            ambiguous_model_answer_cleaned = ambiguous_model_answer.strip().replace('\n', '').replace(' ', '').upper().replace(':', '')

            ## Check if correct based on metric
            if ambiguous_correct_answer_letter == ambiguous_model_answer_cleaned:
                true += 1
                difficulty_results["ambiguous"]['correct'] += 1

            difficulty_results["ambiguous"]['total'] += 1

            # STAGE 2: disambiguated context
            instruction = ""  # placeholder; no extra instruction is appended for this task
            prompt = f"Stage2:\n{disambiguated_prompt}\n{instruction}\n"
            message = prompt
            disambiguated_model_answer = self.generate_response_mcqa_multi_token(message, choices=choices, max_new_tokens=2)
            # Normalize the model output to a bare option letter (e.g. "A").
            disambiguated_model_answer_cleaned = disambiguated_model_answer.strip().replace('\n', '').replace(' ', '').upper().replace(':', '')
            responses.append((ambiguous_model_answer_cleaned, disambiguated_model_answer_cleaned))

            ## Check if correct based on metric
            if disambiguated_correct_answer_letter == disambiguated_model_answer_cleaned:
                true += 1
                difficulty_results["disambiguated"]['correct'] += 1

            difficulty_results["disambiguated"]['total'] += 1

        # Print per-category accuracy (ambiguous vs. disambiguated context).
        for category, stats in difficulty_results.items():
            correct = stats['correct']
            total = stats['total']
            category_accuracy = correct / total if total > 0 else 0
            print(f"{category.capitalize()} Accuracy: {category_accuracy:.2%} ({correct}/{total})")

        print("Results:", responses)
        print("Overall Accuracy:", true / total_count)
        acc = accuracy(true, total_count)
        acc_stderr = accuracy_standard_error(acc, total_count)
        return {"acc": acc, "acc_stderr": acc_stderr}