add commonsense reasoning
src/deepeval/commonsense_reasoning_task.py
ADDED
@@ -0,0 +1,46 @@
from src.deepeval.base_task import BaseTask
from src.deepeval.utils import accuracy, accuracy_standard_error
from typing import Any


class CommonsenseReasoningTask(BaseTask):
    """Multiple-choice cause/effect commonsense task over the metunlp/commonsense dataset."""

    def __init__(self, model_name):
        super().__init__("metunlp/commonsense", model_name=model_name)

    def load_dataset_from_hf(self):
        # Evaluate on at most 10 rows for now.
        dataset = super().load_dataset_from_hf()
        return dataset.select(range(min(10, len(dataset))))
    def evaluate(self) -> dict[str, Any]:
        responses = []
        total_count = len(self.dataset)
        n_correct = 0

        for row in self.dataset:
            sentence = row["sentence"]
            label = row["label"]
            choices = [row["choice1"], row["choice2"]]
            # Render the choices as "A: ..." / "B: ...".
            formatted_choices = "\n".join([f"{chr(65+i)}: {choice}" for i, choice in enumerate(choices)])

            # Pick the Turkish question that matches the relation asked about in this row.
            if label == "effect":
                # "Which of the options could be a result or effect of the given premise?"
                question = "Seçeneklerden hangisi verilen önermenin bir sonucu veya etkisi olabilir?"
            elif label == "cause":
                # "Which of the options could be a cause or reason for the given premise?"
                question = "Seçeneklerden hangisi verilen önermenin bir neden veya sebebi olabilir?"
            else:
                # Fallback: "Which of the options is appropriate?"
                question = "Seçeneklerden hangisi uygun?"

            # Assemble the prompt ("Soru" = question, "Seçenekler" = options).
            prompt = f"Premise:\n{sentence}\nSoru:{question}\nSeçenekler:\n{formatted_choices}"

            messages = prompt

            answer = self.generate_response_mcqa_multi_token(messages, choices=choices)
            print("Answer:", answer)
            responses.append(answer)
            # Score the response by mapping the gold choice to its letter. The gold
            # column is an assumption here: an integer index (0 or 1) in row["answer"].
            correct_answer_letter = chr(65 + int(row["answer"]))
            model_answer_cleaned = answer.strip().replace('\n', '').replace(' ', '').upper()
            if correct_answer_letter == model_answer_cleaned:
                n_correct += 1

        acc = accuracy(n_correct, total_count)
        acc_stderr = accuracy_standard_error(acc, total_count)
        return {"acc": acc, "acc_stderr": acc_stderr}
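
The accuracy and accuracy_standard_error helpers come from src/deepeval/utils, which is not part of this diff. A minimal sketch of what they presumably compute, assuming accuracy is a simple proportion and its uncertainty is the usual binomial standard error:

import math

def accuracy(n_correct: int, total: int) -> float:
    # Fraction of questions answered correctly.
    return n_correct / total

def accuracy_standard_error(acc: float, total: int) -> float:
    # Binomial standard error of a proportion: sqrt(p * (1 - p) / n).
    return math.sqrt(acc * (1 - acc) / total)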
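
For context, a rough sketch of how this task might be invoked. The constructor signature comes from the code above, but model and dataset loading happen in BaseTask, which is outside this diff, so treat this as illustrative only:

if __name__ == "__main__":
    # "your-model-id" is a placeholder; BaseTask is assumed to load the model and dataset.
    task = CommonsenseReasoningTask(model_name="your-model-id")
    results = task.evaluate()
    print(results)  # {"acc": ..., "acc_stderr": ...}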