Ahmet Kaan Sever committed
Commit 76d5f6d · 1 Parent(s): 597b990

Added Turkish General Knowledge task.

Created turkish_general_knowledge_task.py.
Added generate_response_mcqa_multi_token because the original function was not built to handle choices that span multiple tokens.
Also created a .gitignore.
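
For context, multi-token choices are the norm here: a typical Turkish answer string does not map to a single vocabulary id, which is why decoding only the last generated token (as the original generate_response helper does) cannot recover a full choice. A quick illustrative check, assuming access to the gated Llama checkpoint (any locally available tokenizer behaves the same way):

# Illustrative check, not part of the commit. Requires HF access to the
# gated meta-llama checkpoint; substitute any tokenizer you have locally.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
print(tok.encode("İstanbul", add_special_tokens=False))  # usually several ids, not one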

.gitignore ADDED
@@ -0,0 +1,10 @@
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.DS_Store
+.env
+.vscode/
+.idea/
+*.log
+node_modules/
src/deepeval/base_task.py CHANGED
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+import itertools
 from datasets import load_dataset
 import os
 from dotenv import load_dotenv
@@ -71,6 +72,53 @@ class BaseTask(ABC):
         answer = self.tokenizer.decode(output[0][-1])
 
         return answer
+
+    def generate_response_mcqa_multi_token(self, msg, max_new_tokens=5, choices: list = []):
+        """
+        Handles multiple-choice questions where answers might span multiple tokens.
+        """
+        # Ensure the tokenizer has a padding token
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token  # Use EOS token as PAD token
+
+        inputs = self.tokenizer(msg, return_tensors="pt", padding=True, truncation=True)
+        input_ids = inputs.input_ids.to(self.model.device)
+        attention_mask = inputs.attention_mask.to(self.model.device)
+
+        if self.model.config.pad_token_id is None:
+            self.model.config.pad_token_id = self.tokenizer.eos_token_id
+
+        # Tokenize multi-token choices (do not flatten)
+        valid_token_ids = [self.tokenizer.encode(ans, add_special_tokens=False) for ans in choices]
+        print("Valid token IDs:", valid_token_ids)
+
+        class MultipleChoiceLogitsProcessor:
+            def __init__(self, valid_token_ids):
+                self.valid_token_ids = valid_token_ids  # List of tokenized choices
+
+            def __call__(self, input_ids, scores):
+                mask = torch.full_like(scores, float("-inf"))  # Mask everything by default
+
+                # Allow only the tokens that appear in the choices
+                allowed_tokens = {token for tokens in self.valid_token_ids for token in tokens}
+                mask[:, list(allowed_tokens)] = scores[:, list(allowed_tokens)]
+
+                return mask
+
+        logits_processor = LogitsProcessorList([MultipleChoiceLogitsProcessor(valid_token_ids)])
+
+        output = self.model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_new_tokens,
+            logits_processor=logits_processor
+        )
+
+        # Decode only the continuation, then pick the choice the output starts with
+        # (falls back to the first choice if none matches)
+        generated_text = self.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
+        best_match = max(choices, key=lambda choice: generated_text.startswith(choice))
+
+        return best_match
 
     @abstractmethod
     def load_dataset_from_hf(self):
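
One property of MultipleChoiceLogitsProcessor worth noting: it pools every token id from every choice into a single allowed set and applies the same mask at every decoding step, so the model is free to interleave tokens from different choices. The standalone sketch below (my illustration, not repository code; it only needs torch) reproduces the masking step on dummy logits to make that visible:

import torch

# Dummy logits for one decoding step over a 10-token vocabulary.
scores = torch.randn(1, 10)
valid_token_ids = [[2, 5], [7]]  # two tokenized choices, e.g. "An" + "kara" and "İzmir"

# Same masking logic as MultipleChoiceLogitsProcessor.__call__:
mask = torch.full_like(scores, float("-inf"))   # disallow everything by default
allowed = {t for tokens in valid_token_ids for t in tokens}
mask[:, list(allowed)] = scores[:, list(allowed)]

# After softmax, only ids 2, 5, and 7 carry probability mass -- including
# sequences like [7, 5], which belong to no single choice.
print(torch.softmax(mask, dim=-1))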
src/deepeval/deepeval_task_manager.py CHANGED
@@ -1,6 +1,7 @@
 import os
 from dotenv import load_dotenv
 from enum import Enum
+from src.deepeval.turkish_general_knowledge_task import TurkishGeneralKnowledgeTask
 from src.deepeval.sentiment_analysis_task import SentimentAnalysisTask
 from typing import List
 load_dotenv()
@@ -10,6 +11,7 @@ HF_TOKEN=os.getenv("HF_TOKEN")
 class Task(Enum):
     # SUMMARIZATION = "summarization"
     SENTIMENT_ANALYSIS = "sentiment_analysis_tr"
+    TURKISH_GENERAL_KNOWLEDGE = "turkish_general_knowledge"
 
 
 class DeepEvalTaskManager:
@@ -21,6 +23,7 @@ class DeepEvalTaskManager:
     def validate_tasks(self, user_tasks):
         """Validate user tasks and store method references."""
         print(self.available_tasks.keys())
+        print(user_tasks)
         if not set(user_tasks).issubset(self.available_tasks.keys()):
             invalid_tasks = set(user_tasks) - self.available_tasks.keys()
             raise ValueError(f"Invalid task(s) requested: {invalid_tasks}")
@@ -42,9 +45,14 @@ class DeepEvalTaskManager:
         st_task = SentimentAnalysisTask(self.model_name)
         res = st_task.evaluate()
         return res
+
+    def turkish_general_knowledge(self):
+        turkish_general_knowledge_task = TurkishGeneralKnowledgeTask(self.model_name)
+        res = turkish_general_knowledge_task.evaluate()
+        return res
 
 
 if __name__ == "__main__":
-    des = DeepEvalTaskManager("meta-llama/Llama-3.2-1B-Instruct", ["SENTIMENT_ANALYSIS"])
+    des = DeepEvalTaskManager("meta-llama/Llama-3.2-1B-Instruct", ["TURKISH_GENERAL_KNOWLEDGE"])
     res = des.run_tasks()
     print(res)
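
The subset check in validate_tasks is the only guard between user input and task dispatch, so it is worth seeing in isolation. A standalone illustration (my sketch, not repository code; the diff does not show how available_tasks is built):

# Requested task names must be a subset of the available ones,
# mirroring the set logic in validate_tasks above.
available_tasks = {"SENTIMENT_ANALYSIS", "TURKISH_GENERAL_KNOWLEDGE"}

def validate(user_tasks):
    if not set(user_tasks).issubset(available_tasks):
        invalid_tasks = set(user_tasks) - available_tasks
        raise ValueError(f"Invalid task(s) requested: {invalid_tasks}")

validate(["TURKISH_GENERAL_KNOWLEDGE"])  # passes
# validate(["TURKISH_QA"])               # would raise ValueError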
src/deepeval/turkish_general_knowledge_task.py ADDED
@@ -0,0 +1,59 @@
+from src.deepeval.base_task import BaseTask
+from collections import defaultdict
+import ast
+
+class TurkishGeneralKnowledgeTask(BaseTask):
+    def __init__(self, model_name):
+        super().__init__("metunlp/turkish_general_knowledge", model_name=model_name)
+
+    def load_dataset_from_hf(self):
+        dataset = super().load_dataset_from_hf()
+        return dataset.select(range(min(10, len(dataset))))
+
+    def evaluate(self):
+        responses = []
+        difficulty_results = defaultdict(lambda: {'correct': 0, 'total': 0})
+        total_count = 0
+        true = 0
+
+        for row in self.dataset:
+            total_count += 1
+            question = row["question"]
+            choices = ast.literal_eval(row["choices"])  # Convert string to list
+            answer_index = row["answer"]  # Assuming it's a zero-based index
+            difficulty = row["difficulty"]
+
+            print(f"Choices: {choices}")
+            print("Type of choices:", type(choices))
+            # Categorize difficulty
+            if difficulty <= 3:
+                category = 'easy'
+            elif 3 < difficulty <= 6:
+                category = 'medium'
+            else:
+                category = 'hard'
+
+            # Create a multiple-choice prompt to encourage index output
+            formatted_choices = "\n".join([f"{i}: {choice}" for i, choice in enumerate(choices)])
+            prompt = f"Soru: {question}\nSeçenekler:\n{formatted_choices}\nSorunun doğru cevabı hangisidir?"
+
+            print(f"Prompt: {prompt}")
+            model_answer = self.generate_response_mcqa_multi_token(prompt, choices=choices, max_new_tokens=30)
+            responses.append(model_answer)
+            print(f"Correct Answer: {choices[answer_index]}")
+            print(f"Model Answer: {model_answer}")
+            # Check if the answer is correct
+            if choices[answer_index] == model_answer:
+                true += 1
+                difficulty_results[category]['correct'] += 1
+
+            difficulty_results[category]['total'] += 1
+
+        # Print results categorized by difficulty
+        for category, stats in difficulty_results.items():
+            accuracy = stats['correct'] / stats['total'] if stats['total'] > 0 else 0
+            print(f"{category.capitalize()} Accuracy: {accuracy:.2%} ({stats['correct']}/{stats['total']})")
+
+        print("Results:", responses)
+        print("Overall Accuracy:", true / total_count)
+        return true / total_count
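
Two small details in evaluate() deserve a note: the dataset's choices column arrives as a stringified Python list, hence ast.literal_eval, and difficulty is bucketed into three bands. A standalone illustration with a made-up row (mine, not actual dataset content):

import ast

# Hypothetical row shaped like the rows consumed by evaluate() above.
row = {"choices": "['Ankara', 'İstanbul', 'İzmir']", "answer": 0, "difficulty": 4}

choices = ast.literal_eval(row["choices"])  # "['Ankara', ...]" -> real list
assert choices[row["answer"]] == "Ankara"

# Same banding as evaluate(): <=3 easy, 4-6 medium, else hard.
difficulty = row["difficulty"]
category = "easy" if difficulty <= 3 else "medium" if difficulty <= 6 else "hard"
assert category == "medium"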