File size: 11,530 Bytes
cec00dd
 
 
 
6807ea3
b30c279
7a6ddbf
cec00dd
 
66a11b3
cec00dd
 
6807ea3
cec00dd
 
 
 
 
 
 
52b6367
 
 
 
 
 
 
 
 
 
 
 
cec00dd
6807ea3
cec00dd
 
 
 
 
 
 
 
e8c3b4b
cec00dd
e8c3b4b
cec00dd
f05ebc2
66a11b3
e8c3b4b
 
 
 
 
 
 
8a3d32e
f05ebc2
e8c3b4b
cec00dd
 
e8c3b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec00dd
 
 
 
 
 
 
b5edba5
 
cec00dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9fba6d
1657c25
b9fba6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a3d32e
b9fba6d
52b6367
 
 
 
 
 
 
b9fba6d
 
 
 
 
8a3d32e
b9fba6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec00dd
 
 
 
 
 
 
 
 
6807ea3
f17e8ce
6807ea3
 
 
f17e8ce
6807ea3
 
 
 
 
 
 
796c1e7
52b6367
 
 
 
 
 
6807ea3
cec00dd
 
 
 
 
 
 
6807ea3
 
 
 
 
cec00dd
 
2ac340e
 
 
 
 
 
 
d3c5563
cec00dd
 
 
 
 
 
761dc11
66a11b3
761dc11
 
9828c0e
04b81d8
796c1e7
 
ec9c39a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66a11b3
8a3d32e
dbf76bc
cec00dd
 
 
51ae401
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
from abc import ABC, abstractmethod
from datasets import load_dataset
import os
from dotenv import load_dotenv
import openai
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
import torch
from typing import List
from datetime import datetime
# Pull credentials from a local .env file into the process environment.
load_dotenv()
# Hugging Face Hub access token used for gated models/datasets.
HF_TOKEN=os.getenv("HF_TOKEN")
# OpenAI API key; assigned to the openai client in BaseTask.__init__.
OPENAI_KEY = os.getenv("OPENAI_API_KEY")

class BaseTask(ABC):
    """Abstract base class for evaluation tasks that pair a Hugging Face
    dataset with a (cached) causal language model.

    Subclasses must implement :meth:`load_dataset_from_hf` (they may call
    ``super().load_dataset_from_hf()`` for the default behavior) and
    :meth:`evaluate`.
    """

    # Class-level cache so every task instance sharing a model name reuses
    # the same (model, tokenizer) pair instead of reloading the weights.
    _model_cache = {}

    def __init__(self, dataset_repo, model_name):
        """Load the task dataset and the (possibly cached) model/tokenizer.

        :param dataset_repo: Hugging Face dataset repository id.
        :param model_name: Hugging Face model repository id.
        """
        self.dataset_repo = dataset_repo
        self.dataset = self.load_dataset_from_hf()

        # Choose a device_map strategy from the available GPU count:
        # "auto" lets transformers shard across multiple GPUs.
        device_count = torch.cuda.device_count()
        if device_count > 1:
            self.device = "auto"
            print(f"Using {device_count} GPUs with auto config.")
        elif device_count == 1:
            self.device = "cuda"
            print(f"Using {device_count} GPU with cuda config.")
        else:
            self.device = "cpu"
            print("No GPU found. Using CPU.")

        self.model, self.tokenizer = self.get_cached_model(model_name, self.device)
        openai.api_key = OPENAI_KEY

    @classmethod
    def get_cached_model(cls, model_name, device):
        """Return the cached (model, tokenizer) for ``model_name``, loading on first use.

        NOTE(review): the cache key is the model name only, so the ``device``
        of the first request wins for later lookups of the same model.
        """
        if model_name not in cls._model_cache:
            cls._model_cache[model_name] = cls.load_model(model_name, device)
        return cls._model_cache[model_name]

    @staticmethod
    def load_model(model_name: str, device):
        """Load a fp16 causal LM and its tokenizer from the Hugging Face Hub.

        :param model_name: Hub repository id of the model.
        :param device: device_map value ("auto", "cuda" or "cpu").
        :return: tuple of (model, tokenizer).
        """
        print(f"Loading model: {model_name}")
        start_time = datetime.now()
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map=device,
            token=HF_TOKEN,
        )
        # total_seconds() instead of .seconds: .seconds drops the days
        # component and silently truncates for very long loads.
        elapsed = (datetime.now() - start_time).total_seconds()
        print(f"Model loaded in {elapsed:.0f} seconds.")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return model, tokenizer

    def generate_response_mcqa(self, msg, max_new_tokens=1, choices: List[str] = None):
        """Answer a multiple-choice question with a single constrained token.

        Generation is restricted to the token ids of ``choices`` via a logits
        processor, so the model can only emit one of the given answers.

        :param msg: fully formatted prompt string.
        :param max_new_tokens: number of tokens to generate (default 1).
        :param choices: answer strings; each is assumed to map to a single
            token — TODO confirm for multi-character choices.
        :return: decoded final token of the generated sequence.
        """
        # Avoid a mutable default argument; None means "no choices given".
        if choices is None:
            choices = []

        # Ensure the tokenizer has a padding token (fall back to EOS).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        inputs = self.tokenizer(msg, return_tensors="pt", padding=True, truncation=True)

        # Fix: move tensors to the model device (previously omitted here,
        # unlike the sibling generate methods — crashed on CUDA setups).
        if self.device == "auto":
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
        else:
            input_ids = inputs.input_ids.to(self.model.device)
            attention_mask = inputs.attention_mask.to(self.model.device)

        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.eos_token_id

        # Token ids of the allowed answer choices.
        valid_token_ids = [self.tokenizer.convert_tokens_to_ids(ans) for ans in choices]

        # Subclass LogitsProcessor for consistency with the multi-token variant.
        class MultipleChoiceLogitsProcessor(LogitsProcessor):
            def __call__(self, input_ids, scores):
                mask = torch.full_like(scores, float("-inf"))
                mask[:, valid_token_ids] = scores[:, valid_token_ids]  # Allow only valid tokens
                return mask

        logits_processor = LogitsProcessorList([MultipleChoiceLogitsProcessor()])

        output = self.model.generate(
            input_ids,
            attention_mask=attention_mask,  # avoids the missing-mask warning
            max_new_tokens=max_new_tokens,
            logits_processor=logits_processor
        )
        # Decode only the last generated token — the answer letter.
        answer = self.tokenizer.decode(output[0][-1])
        return answer

    def generate_response_mcqa_multi_token(self, msg, max_new_tokens=2, choices: list = None):
        """Answer a multiple-choice question whose option letters may span tokens.

        Builds a chat prompt instructing the model to answer with a letter,
        then restricts sampling to the letters A.. (one per choice) plus the
        tokenizer's chat-template special tokens.

        :param msg: question text (inserted as the final user turn).
        :param max_new_tokens: generation budget (default 2).
        :param choices: list of answer options; only its length is used.
        :return: decoded generated text (the chosen letter).
        """
        # Avoid a mutable default argument; None means "no choices given".
        if choices is None:
            choices = []

        # Ensure tokenizer/model have proper padding tokens set.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        chat = [
                {"role": "user", "content": "You are a multiple choice question-answering chatbot. Do not give an answer that is not included in the choices. Only answer with letters like A, B, C, D..."},
                {"role": "assistant", "content": "I am ready to answer your questions. Feel free to ask anything.\n"},
                {"role": "user", "content": f"{msg}"},
            ]
        formatted_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(formatted_chat, return_tensors="pt", padding=True, truncation=True)

        # With device_map="auto" the model is sharded; leave tensors on CPU
        # and let transformers dispatch them. Otherwise move explicitly.
        if self.device == "auto":
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
        else:
            input_ids = inputs.input_ids.to(self.model.device)
            attention_mask = inputs.attention_mask.to(self.model.device)

        # Option letters A, B, C, ... — one per choice.
        letters = [chr(ord('A') + i) for i in range(len(choices))]
        encoded_choices = [self.tokenizer.encode(letter, add_special_tokens=False) for letter in letters]
        flattened_encoded_choices = [item for sublist in encoded_choices for item in sublist]

        # Allow the letters plus the chat-template special tokens so the
        # model can terminate its turn normally.
        allowed_tokens = flattened_encoded_choices
        allowed_tokens += self.get_chat_template_tokens()
        allowed_token_ids = set(allowed_tokens)

        # Custom LogitsProcessor: block everything outside the allowed set.
        class RestrictToABCDLogitsProcessor(LogitsProcessor):
            def __call__(self, input_ids, scores):
                mask = torch.full_like(scores, float("-inf"))
                mask[:, list(allowed_token_ids)] = scores[:, list(allowed_token_ids)]
                return mask
        logits_processor = LogitsProcessorList([RestrictToABCDLogitsProcessor()])

        output = self.model.generate(
            input_ids,
            do_sample=True,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            temperature=0.4,
            logits_processor=logits_processor,
        )
        # Strip the prompt; decode only the newly generated tokens.
        generated_ids = output[0]
        generated_tokens = generated_ids[len(input_ids[0]):]
        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return generated_text

    def generate_response(self, prompt: str, max_new_tokens: int = 100) -> str:
        """Generate a free-form chat response to ``prompt``.

        :param prompt: user message inserted as the final chat turn.
        :param max_new_tokens: generation budget (default 100).
        :return: decoded model reply (prompt excluded).
        """
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.eos_token_id

        chat = [
            {"role": "user", "content": "You are a helpful AI assistant."},
            {"role": "assistant", "content": "I am here to help you with any questions you may have."},
            {"role": "user", "content": prompt},
        ]
        formatted_chat = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = self.tokenizer(formatted_chat, return_tensors="pt", padding=True, truncation=True)

        if self.device == "auto":
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
        else:
            input_ids = inputs.input_ids.to(self.model.device)
            attention_mask = inputs.attention_mask.to(self.model.device)

        output = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )

        # Decode only the tokens generated after the prompt.
        generated_ids = output[0]
        prompt_len = input_ids.shape[1]
        generated_tokens = generated_ids[prompt_len:]
        result = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return result

    def get_chat_template_tokens(self):
        """Return the token ids a minimal empty chat template produces.

        Used to whitelist the tokenizer's special chat tokens during
        constrained generation.
        """
        allowed_token_chat = [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": ""}
        ]
        allowed_special_tokens = self.tokenizer.apply_chat_template(allowed_token_chat, tokenize=True)
        return allowed_special_tokens

    def _load_limited_dataset(self, limit):
        """Load the task dataset from the Hub and subsample it to ``limit`` rows.

        Shared implementation for the two public loaders below (previously
        copy-pasted bodies differing only in the sample size).

        :param limit: maximum number of rows kept (seeded shuffle-then-select).
        :return: Dataset
        """
        print("Loading dataset from Hugging Face.")
        start_time = datetime.now()
        dataset = load_dataset(self.dataset_repo, token=HF_TOKEN, split="train")
        print("Dataset loaded.")
        # Fixed seed keeps the subsample reproducible across runs.
        if len(dataset) > limit:
            dataset = dataset.shuffle(seed=42).select(range(limit))
        elapsed = (datetime.now() - start_time).total_seconds()
        print(f"Dataset loaded in {elapsed:.0f} seconds.")
        return dataset

    @abstractmethod
    def load_dataset_from_hf(self):
        """Load the evaluation split (up to 50 rows by default).

        Subclasses override this; they may call ``super()`` for the default.

        :return: Dataset
        """
        return self._load_limited_dataset(50)

    def load_dataset_lmjudge_from_hf(self):
        """Load a smaller split (up to 10 rows) for LM-as-judge evaluation.

        :return: Dataset
        """
        return self._load_limited_dataset(10)

    @abstractmethod
    def evaluate(self):
        """Run the task's evaluation; implemented by subclasses."""
        pass