aacengiz committed
Commit 17e229f · 1 Parent(s): 7b3d3a5

update commonsense reasoning

src/deepeval/commonsense_reasoning_task.py CHANGED
@@ -1,27 +1,47 @@
 from src.deepeval.base_task import BaseTask
+from collections import defaultdict
 from src.deepeval.utils import accuracy, accuracy_standard_error
 from typing import Any

+
 class CommonsenseReasoningTask(BaseTask):
     def __init__(self, model_name):
         super().__init__("metunlp/commonsense", model_name=model_name)

     def load_dataset_from_hf(self):
-        print("Loading the dataset")
         dataset = super().load_dataset_from_hf()
-        return dataset.select(range(min(1, len(dataset))))
+        return dataset.select(range(min(2, len(dataset))))


     def evaluate(self) -> dict[str, Any]:
         responses = []
-        total_count = len(self.dataset)
-        n_correct = 0
+        difficulty_results = defaultdict(lambda: {'correct': 0, 'total': 0})
+        total_count = 0
+        true = 0
+
         for row in self.dataset:
-            print(row)
+            total_count += 1
+
+            # Get values from row
             label = row["label"]
             choices = [row["choice1"], row["choice2"]]
             formatted_choices = "\n".join([f"{chr(65+i)}: {choice}" for i, choice in enumerate(choices)])
+            category = row["difficulty"]
+            answer = row["answer"]
+
+            # Prints for debugging
+            print(f"Choices: {choices}")
+            print("Type of choices:", type(choices))
+            print("Type of answer:", type(answer))

+            # Get answer index (starting from 0)
+            if type(answer) == int:
+                answer_index = answer - 1  # 1 or 2
+            else:
+                answer_index = int(answer) - 1
+            correct_answer_letter = chr(65 + answer_index)
+
+            # Get question based on label
             if label == "effect":
                 question = "Seçeneklerden hangisi verilen önermenin bir sonucu veya etkisi olabilir?"
             elif label == "cause":
@@ -29,21 +49,36 @@ class CommonsenseReasoningTask(BaseTask):
             else:
                 question = "Seçeneklerden hangisi uygun?"  # Alternatif

-            prompt = f"Bağlam:\n{row['text']}\nÖnerme:\n{row['context']}\nSoru:{question}\nSeçenekler:\n{formatted_choices}"
-
-            messages = prompt
+            # Construct the prompt/message
+            instruction = ""
+            prompt = f"Bağlam:\n{row['text']}\nÖnerme:\n{row['context']}\nSoru:{question}\nSeçenekler:\n{formatted_choices}\n{instruction}\n"
+            message = prompt

-            model_answer = self.generate_response_mcqa_multi_token(messages, choices=choices)
-
-            correct_answer_letter = "A" if row["answer"] == 1 else "B" if row["answer"] == 2 else None
+            # Get/format answer of the model
+            model_answer = self.generate_response_mcqa_multi_token(message, choices=choices, max_new_tokens=10)
+            responses.append(model_answer)
             model_answer_cleaned = model_answer.strip().replace('\n', '').replace(' ', '').upper()
-            if correct_answer_letter == model_answer_cleaned:
-                n_correct += 1
+
+            # Print answers
             print(f"Correct Answer: {correct_answer_letter}")
             print(f"Model Answer: {model_answer}")
             print(f"Model Answer Cleaned: {model_answer_cleaned}")

-        acc = accuracy(n_correct, total_count)
+            # Check if correct based on metric
+            if correct_answer_letter == model_answer_cleaned:
+                true += 1
+                difficulty_results[category]['correct'] += 1
+
+            difficulty_results[category]['total'] += 1
+
+        # Print results categorized by difficulty
+        for category, stats in difficulty_results.items():
+            calculatedAccuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
+            print(f"{category.capitalize()} Accuracy: {calculatedAccuracy:.2%} ({stats['correct']}/{stats['total']})")
+
+        print("Results:", responses)
+        print("Overall Accuracy:", true / total_count)
+        acc = accuracy(true, total_count)
         acc_stderr = accuracy_standard_error(acc, total_count)
         return {"acc": acc, "acc_stderr": acc_stderr}
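The helpers imported from src.deepeval.utils are not part of this commit, so the following is only a minimal sketch of what they presumably compute: accuracy as the plain proportion of correct answers, and accuracy_standard_error as the normal-approximation (binomial) standard error of that proportion. The function bodies and the example numbers below are assumptions, not code from the repository.

import math

def accuracy(n_correct: int, n_total: int) -> float:
    # Proportion of items answered correctly.
    return n_correct / n_total if n_total > 0 else 0.0

def accuracy_standard_error(acc: float, n_total: int) -> float:
    # Normal-approximation standard error of a proportion: sqrt(p * (1 - p) / n).
    return math.sqrt(acc * (1.0 - acc) / n_total) if n_total > 0 else 0.0

# With the two-row dataset selected by load_dataset_from_hf(), one correct answer would give:
acc = accuracy(1, 2)                           # 0.5
acc_stderr = accuracy_standard_error(acc, 2)   # ~0.354
print({"acc": acc, "acc_stderr": acc_stderr})

If the real utils module computes these differently, only the two helpers change; evaluate() simply forwards true and total_count and returns whatever they produce as {"acc", "acc_stderr"}.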