Ahmet Kaan Sever committed on
Commit
b9fba6d
·
1 Parent(s): bb84099

Re-added function that was lost in a previous commit

Browse files
Files changed (1) hide show
  1. src/deepeval/base_task.py +56 -0
src/deepeval/base_task.py CHANGED
@@ -71,6 +71,62 @@ class BaseTask(ABC):
71
  answer = self.tokenizer.decode(output[0][-1])
72
 
73
  return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  def generate_response(self, prompt: str, max_new_tokens: int = 100) -> str:
76
 
 
71
  answer = self.tokenizer.decode(output[0][-1])
72
 
73
  return answer
74
+
75
+ def generate_response_mcqa_multi_token(self, msg, max_new_tokens=5, choices: list = []):
76
+ """
77
+ Handles multiple-choice questions where answers might have multiple tokens.
78
+ """
79
+ # Ensure tokenizer has proper special tokens set
80
+ if self.tokenizer.pad_token is None:
81
+ self.tokenizer.pad_token = self.tokenizer.eos_token
82
+
83
+ if self.model.config.pad_token_id is None:
84
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
85
+
86
+ chat = [
87
+ {"role": "user", "content": "You are a multiple choice question-answering chatbot. Do not give an answer that is not included in the choices. Only answer with letters like A, B, C, D..."},
88
+ {"role": "assistant", "content": "I am ready to answer your questions. Feel free to ask anything.\n"},
89
+ {"role": "user", "content": f"{msg}"},
90
+ ]
91
+ formatted_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
92
+ print(formatted_chat)
93
+ inputs = self.tokenizer(formatted_chat, return_tensors="pt", padding=True, truncation=True)
94
+ input_ids = inputs.input_ids.to(self.model.device)
95
+ attention_mask = inputs.attention_mask.to(self.model.device)
96
+
97
+ # Generate the sequence of letters starting from 'A'
98
+ letters = [chr(ord('A') + i) for i in range(len(choices))] # Create option letters A, B, C, D, E, ...
99
+ encoded_choices = [self.tokenizer.encode(letter, add_special_tokens=False) for letter in letters]
100
+ flattened_encoded_choices = [item for sublist in encoded_choices for item in sublist] # Flatten the list
101
+ print(flattened_encoded_choices)
102
+
103
+ allowed_tokens = flattened_encoded_choices
104
+ allowed_tokens += self.get_chat_template_tokens() # Get the special chat tokens
105
+ allowed_token_ids = set(allowed_tokens) # Ensure uniqueness
106
+
107
+ # Custom LogitsProcessor to restrict generation
108
+ class RestrictToABCDLogitsProcessor(LogitsProcessor):
109
+ def __call__(self, input_ids, scores):
110
+ mask = torch.full_like(scores, float("-inf")) # Block all tokens
111
+ mask[:, list(allowed_token_ids)] = scores[:, list(allowed_token_ids)] # Allow only A, B, C, D tokens
112
+ return mask
113
+ logits_processor = LogitsProcessorList([RestrictToABCDLogitsProcessor()])
114
+
115
+ # Generate response
116
+ output = self.model.generate(
117
+ input_ids,
118
+ do_sample=True,
119
+ attention_mask=attention_mask,
120
+ max_new_tokens=max_new_tokens,
121
+ eos_token_id=self.tokenizer.eos_token_id,
122
+ pad_token_id=self.tokenizer.pad_token_id,
123
+ temperature=0.4,
124
+ logits_processor=logits_processor,
125
+ )
126
+ generated_ids = output[0] # The generated sequence including the prompt
127
+ generated_tokens = generated_ids[len(input_ids[0]):] # Exclude the input_ids part
128
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
129
+ return generated_text
130
 
131
  def generate_response(self, prompt: str, max_new_tokens: int = 100) -> str:
132