|
from utils.arguments import TrainingArguments, DataArguments, LoraArguments |
|
|
|
def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    """Add special tokens to the tokenizer and resize the model embeddings.

    The rows created for the newly added tokens are initialised with the
    mean of the rows that existed before the resize, for both the input and
    the output embedding layers.

    Note: This is the unoptimized version that may make your embedding size
    not be divisible by 64.

    Args:
        special_tokens_dict: Passed unchanged to
            ``tokenizer.add_special_tokens``.
        tokenizer: Object providing ``add_special_tokens`` and ``__len__``.
        model: Object providing ``resize_token_embeddings``,
            ``get_input_embeddings`` and ``get_output_embeddings``.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    # Nothing was appended, so there are no fresh rows to initialise.
    if num_new_tokens <= 0:
        return

    for layer in (model.get_input_embeddings(), model.get_output_embeddings()):
        # Some models expose no output embedding layer; the getter then
        # yields None and there is simply nothing to initialise.
        if layer is None:
            continue
        weights = layer.weight.data
        # Only the pre-existing rows contribute to the mean, so the result is
        # the same even when input and output layers share one weight tensor.
        weights[-num_new_tokens:] = weights[:-num_new_tokens].mean(
            dim=0, keepdim=True)
|
|
|
def reshape_model_embedding(tokenizer, model):
    """Make the model's embedding row count match ``len(tokenizer)``.

    Does nothing when the sizes already agree. Otherwise the embeddings are
    resized to ``len(tokenizer)``. If the resize added rows, the added rows
    are initialised with the mean of the previously existing rows, for both
    the input and the output embedding layers. If the resize removed rows,
    the remaining rows are left exactly as the resize produced them.

    Args:
        tokenizer: Object providing ``__len__``.
        model: Object providing ``resize_token_embeddings``,
            ``get_input_embeddings`` and ``get_output_embeddings``.
    """
    token_length = len(tokenizer)
    embedding_length = model.get_input_embeddings().num_embeddings
    if token_length == embedding_length:
        return

    num_new_tokens = token_length - embedding_length
    model.resize_token_embeddings(token_length)

    # A negative count means the matrix shrank. Slicing with a negated
    # negative number would flip the meaning of both slices and overwrite
    # rows of real tokens with a mean vector, so initialisation is only done
    # when rows were actually added.
    if num_new_tokens <= 0:
        return

    for layer in (model.get_input_embeddings(), model.get_output_embeddings()):
        # Some models expose no output embedding layer (getter yields None).
        if layer is None:
            continue
        weights = layer.weight.data
        weights[-num_new_tokens:] = weights[:-num_new_tokens].mean(
            dim=0, keepdim=True)
|
|
|
class BaseModel:
    """Base class that drives model/tokenizer setup through overridable hooks.

    Construction stores the arguments and then runs, in order,
    ``load_model_tokenizer``, ``configure_special_tokens``,
    ``configure_training_args`` and ``configure_peft``. Three of those hooks
    raise ``NotImplementedError`` here and must be supplied by a subclass.
    """

    def __init__(
        self,
        model_path,
        training_args: TrainingArguments,
        data_args: DataArguments,
        lora_args: LoraArguments,
        use_caption=None,
    ):
        """Store the configuration and run the setup hooks.

        Args:
            model_path: Stored as ``self.model_path`` for use by the hooks.
            training_args: Stored as ``self.training_args``.
            data_args: Stored as ``self.data_args``.
            lora_args: Stored as ``self.lora_args``.
            use_caption: Optional mapping consulted by
                ``configure_special_tokens``; ``None`` (or any falsy value)
                disables the extra ``[EOT]`` token.
        """
        self.model_path = model_path
        self.training_args = training_args
        self.data_args = data_args
        self.lora_args = lora_args
        self.use_caption = use_caption

        self.load_model_tokenizer()
        self.configure_special_tokens()
        self.configure_training_args()
        self.configure_peft()

        # Best-effort report: the model may not provide
        # print_trainable_parameters at all. Catch Exception rather than
        # using a bare except so that KeyboardInterrupt / SystemExit are not
        # silently swallowed.
        try:
            self.model.print_trainable_parameters()
        except Exception:
            pass
        print('lljllj self model use_cache :', self.model.config.use_cache, flush=True)

    def configure_special_tokens(self):
        """Align the embedding size with the tokenizer and attach the tokenizer.

        When ``use_caption`` is truthy and its ``'text_pool'`` entry
        (defaulting to ``'eot'``) equals ``'eot'``, an ``'[EOT]'`` additional
        special token is added and the embeddings are resized for it.
        Otherwise the embeddings are only reshaped to match the tokenizer.
        In both cases the tokenizer is then stored on the model as
        ``model.tokenizer``.
        """
        if self.use_caption and self.use_caption.get('text_pool', 'eot') == 'eot':
            eot_token = '[EOT]'
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(additional_special_tokens=[eot_token]),
                tokenizer=self.tokenizer,
                model=self.model)
        else:
            reshape_model_embedding(self.tokenizer, self.model)
        self.model.tokenizer = self.tokenizer

    def load_model_tokenizer(self):
        """Hook for subclasses; not implemented here.

        ``__init__`` and ``configure_special_tokens`` read ``self.model`` and
        ``self.tokenizer`` afterwards, so an override is expected to set them.
        """
        raise NotImplementedError

    def configure_training_args(self):
        """Hook for subclasses; not implemented here."""
        raise NotImplementedError

    def configure_peft(self):
        """Hook for subclasses; not implemented here."""
        raise NotImplementedError

    def get_model_tokenizer(self):
        """Return the ``(model, tokenizer)`` pair."""
        return self.model, self.tokenizer

    def get_model_processor(self):
        """Return the ``(model, processor)`` pair.

        ``self.processor`` is not assigned anywhere in this class, so it must
        be set elsewhere (presumably by a subclass) before this is called.
        """
        return self.model, self.processor