aacengiz commited on
Commit
847b372
·
1 Parent(s): 48b440e

add commonsense reasoning

Browse files
src/deepeval/commonsense_reasoning_task.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.deepeval.base_task import BaseTask
2
+ from src.deepeval.utils import accuracy, accuracy_standard_error
3
+ from typing import Any
4
+
5
+ class SentimentAnalysisTask(BaseTask):
6
+ def __init__(self, model_name):
7
+ super().__init__("metunlp/commonsense", model_name=model_name)
8
+
9
+ def load_dataset_from_hf(self):
10
+ dataset = super().load_dataset_from_hf()
11
+ return dataset.select(range(min(10, len(dataset))))
12
+
13
+
14
+ def evaluate(self) -> dict[str, Any]:
15
+ responses = []
16
+ total_count = len(self.dataset)
17
+ n_correct = 0
18
+ for row in self.dataset:
19
+ sentence = row["sentence"]
20
+ label = row["label"]
21
+ choices=[row["choice1"], row["choice2"]]
22
+ formatted_choices = "\n".join([f"{chr(65+i)}: {choice}" for i, choice in enumerate(choices)])
23
+
24
+ if label == "effect":
25
+ question = "Seçeneklerden hangisi verilen önermenin bir sonucu veya etkisi olabilir?"
26
+ elif label == "cause":
27
+ question = "Seçeneklerden hangisi verilen önermenin bir neden veya sebebi olabilir?"
28
+ else:
29
+ question = "Seçeneklerden hangisi uygun?" # Alternatif
30
+
31
+ prompt = f"Premise:\n{line["text"]}\nSoru:{question}\nSeçenekler:\n{formatted_choices}"
32
+
33
+ messages = prompt
34
+
35
+ answer = self.generate_response_mcqa_multi_token(messages, choices=choices)
36
+ print("Answer:", answer)
37
+ responses.append(answer)
38
+ correct_answer_letter = "A" if row["sentiment"] == "positive" else "B" if row["sentiment"] == "negative" else "C" if row["sentiment"] == "neutral" else None
39
+ model_answer_cleaned = answer.strip().replace('\n', '').replace(' ', '').upper()
40
+ if correct_answer_letter == model_answer_cleaned:
41
+ n_correct += 1
42
+
43
+ acc = accuracy(n_correct, total_count)
44
+ acc_stderr = accuracy_standard_error(acc, total_count)
45
+ return {"acc": acc, "acc_stderr": acc_stderr}
46
+