# Spaces: Running on L4
import os
import re
from collections import defaultdict
from typing import Any, List, Optional

import openai
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList

from src.deepeval.base_task import BaseTask
from src.deepeval.utils import accuracy, accuracy_standard_error
class STSTask(BaseTask):
    """Sentence-similarity scoring task evaluated on ``metunlp/sts_tr``.

    For every dataset row the model is shown two sentences and asked (with a
    Turkish-language prompt) for a single integer from 0 to 5. The reported
    accuracy is the share of rows whose answer equals the row's ``score``.
    """

    # The only score labels the model may produce during decoding.
    _SCORE_LABELS = ("0", "1", "2", "3", "4", "5")

    def __init__(self, model_name):
        """Initialise the base task with ``"metunlp/sts_tr"`` and ``model_name``."""
        super().__init__("metunlp/sts_tr", model_name=model_name)

    def load_dataset_from_hf(self):
        """Return whatever the base class loader returns, unchanged."""
        return super().load_dataset_from_hf()

    @staticmethod
    def _canonical_score(value) -> str:
        """Return the gold score as a plain integer string when it is integral.

        The dataset may store the score as ``3``, ``3.0`` or ``"3"``; all of
        these become ``"3"`` so they can be compared with the model's textual
        answer. Values that are not integral numbers are returned as their
        stripped string form.
        """
        text = str(value).strip()
        try:
            number = float(text)
        except ValueError:
            return text
        return str(int(number)) if number.is_integer() else text

    def generate_response_sts_multi_token(self, msg, max_new_tokens=5, choices: Optional[list] = None):
        """Generate a similarity score with decoding restricted to score tokens.

        Args:
            msg: User message holding the instruction and the sentence pair.
            max_new_tokens: Maximum number of tokens to generate.
            choices: Accepted for interface compatibility only; it is not
                consulted. The allowed labels are always the digits 0-5.

        Returns:
            The decoded continuation, without the prompt and with special
            tokens skipped.

        Note:
            Generation samples (``do_sample=True``, ``temperature=0.4``), so
            repeated calls with the same message may return different text.
        """
        tokenizer = self.tokenizer
        model = self.model

        # Padding needs a pad token; fall back to EOS when none is configured.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizer.pad_token_id

        chat = [
            {"role": "user",
             "content": "You are a sentence similarity scoring chatbot. Only respond with one of the given scores: 0, 1, 2, 3, 4, or 5."},
            {"role": "assistant", "content": "I am ready to answer your questions. Feel free to ask anything.\n"},
            {"role": "user", "content": f"{msg}"},
        ]
        formatted_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(formatted_chat, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs.input_ids.to(model.device)
        attention_mask = inputs.attention_mask.to(model.device)

        # Collect every token id that may be generated: all pieces of each
        # score label (a label can tokenize into more than one id) ...
        allowed = set()
        for label in self._SCORE_LABELS:
            allowed.update(tokenizer.encode(label, add_special_tokens=False))
        # ... plus the chat-template tokens from the base class (presumably so
        # the model can still end its turn -- confirm against BaseTask).
        allowed.update(self.get_chat_template_tokens())
        # Built once here instead of on every decoding step.
        allowed_ids = list(allowed)

        class RestrictToScoreTokens(LogitsProcessor):
            """Set the logit of every token outside ``allowed_ids`` to -inf."""

            def __call__(self, input_ids, scores):
                restricted = torch.full_like(scores, float("-inf"))
                restricted[:, allowed_ids] = scores[:, allowed_ids]
                return restricted

        logits_processor = LogitsProcessorList([RestrictToScoreTokens()])

        output = model.generate(
            input_ids,
            do_sample=True,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            temperature=0.4,
            logits_processor=logits_processor,
        )

        # generate() returns prompt + continuation; keep only the continuation.
        prompt_length = input_ids.shape[1]
        generated_tokens = output[0][prompt_length:]
        return tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def evaluate(self) -> dict[str, Any]:
        """Score every dataset row and report exact-match accuracy.

        Returns:
            A dict with ``"acc"`` (from ``accuracy``) and ``"acc_stderr"``
            (from ``accuracy_standard_error``).
        """
        instruction = "Aşağıda verilen iki cümlenin birbirlerine olan anlamsal benzerliğini 0'dan 5'e kadar olan bir tam sayıyla söyleyin."
        responses = []
        correct = 0
        total = 0

        for row in self.dataset:
            total += 1
            sentence_1 = row["sentence_1"]
            sentence_2 = row["sentence_2"]
            prompt = f"{instruction}\nCümle 1: {sentence_1}\nCümle 2: {sentence_2}\nSadece tek bir tam sayı söyleyin, ek bir kelime ya da sembol kullanmayın."

            model_answer = self.generate_response_sts_multi_token(prompt, max_new_tokens=2)
            responses.append(model_answer)
            model_answer_cleaned = model_answer.strip().replace('\n', '').replace(' ', '').upper().replace(':', '')

            # The gold score may be numeric while the answer is text, so the
            # gold value is canonicalised to a string before comparing.
            if self._canonical_score(row["score"]) == model_answer_cleaned:
                correct += 1

        # Guarded so an empty dataset reports 0 instead of dividing by zero.
        overall = correct / total if total > 0 else 0
        print(f"Accuracy: {overall:.2%} ({correct}/{total})")
        print("Results:", responses)
        print("Overall Accuracy:", overall)

        acc = accuracy(correct, total)
        acc_stderr = accuracy_standard_error(acc, total)
        return {"acc": acc, "acc_stderr": acc_stderr}