from utils.arguments import TrainingArguments, DataArguments, LoraArguments

def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    """Resize tokenizer and embedding.

    Note: this is the unoptimized version; it may leave the embedding size not divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) # TODO: maybe save special tokens in tokenizer
    model.resize_token_embeddings(len(tokenizer))  # this sets lm_head and embed_tokens requires_grad to True

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
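
# A minimal usage sketch for smart_tokenizer_and_embedding_resize. It assumes a
# Hugging Face `transformers` causal LM and its tokenizer; the model path below is
# a placeholder, not something shipped with this repo:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/base-model")
#     model = AutoModelForCausalLM.from_pretrained("path/to/base-model")
#     smart_tokenizer_and_embedding_resize(
#         special_tokens_dict=dict(additional_special_tokens=['[EOT]']),
#         tokenizer=tokenizer,
#         model=model)
#     # len(tokenizer) now equals model.get_input_embeddings().num_embeddings, and the
#     # new '[EOT]' row is initialized to the mean of the pre-existing embedding rows.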

def reshape_model_embedding(tokenizer, model):
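    """Resize the model's token embeddings to match len(tokenizer).

    When the tokenizer is longer than the current embedding table, the newly added
    input/output rows are initialized to the mean of the pre-existing rows, mirroring
    smart_tokenizer_and_embedding_resize.
    """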
    token_length = len(tokenizer)
    embedding_length = model.get_input_embeddings().num_embeddings
    if token_length != embedding_length:
        num_new_tokens = token_length - embedding_length
        model.resize_token_embeddings(len(tokenizer))  # this sets lm_head and embed_tokens requires_grad to True
        if num_new_tokens > 0:
            # Initialize only newly added rows; skip this when the table shrank.
            input_embeddings = model.get_input_embeddings().weight.data
            output_embeddings = model.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg

class BaseModel:
    def __init__(
        self,
        model_path,
        training_args: TrainingArguments,
        data_args: DataArguments,
        lora_args: LoraArguments,
        use_caption = None,
    ):
        self.model_path = model_path
        self.training_args = training_args
        self.data_args = data_args
        self.lora_args = lora_args
        self.use_caption = use_caption

        self.load_model_tokenizer()
        self.configure_special_tokens()
        self.configure_training_args()
        self.configure_peft()
        try:
            self.model.print_trainable_parameters()
        except AttributeError:
            # Only PEFT-wrapped models expose print_trainable_parameters.
            pass
        print('self.model.config.use_cache:', self.model.config.use_cache, flush=True)

    def configure_special_tokens(self):
        if self.use_caption and self.use_caption.get('text_pool', 'eot') == 'eot': 
            eot_token = '[EOT]'
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(additional_special_tokens=[eot_token]),
                tokenizer=self.tokenizer,
                model=self.model)
        else:
            reshape_model_embedding(self.tokenizer, self.model)
        self.model.tokenizer = self.tokenizer

    def load_model_tokenizer(self):
        raise NotImplementedError

    def configure_training_args(self):
        raise NotImplementedError
    
    def configure_peft(self):
        raise NotImplementedError

    def get_model_tokenizer(self):
        return self.model, self.tokenizer
    
    def get_model_processor(self):
        return self.model, self.processor
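
# A minimal subclass sketch of BaseModel. Everything below is illustrative: the
# attribute names on lora_args and the concrete libraries (transformers, peft) are
# assumptions, not part of BaseModel's contract.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     from peft import LoraConfig, get_peft_model
#
#     class CausalLMModel(BaseModel):
#         def load_model_tokenizer(self):
#             self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
#             self.model = AutoModelForCausalLM.from_pretrained(self.model_path)
#
#         def configure_training_args(self):
#             # e.g. disable the KV cache when gradient checkpointing is enabled
#             self.model.config.use_cache = False
#
#         def configure_peft(self):
#             if self.lora_args.lora_enable:          # hypothetical flag
#                 lora_config = LoraConfig(
#                     r=self.lora_args.lora_r,        # hypothetical fields
#                     lora_alpha=self.lora_args.lora_alpha,
#                     target_modules=["q_proj", "v_proj"],
#                     task_type="CAUSAL_LM",
#                 )
#                 self.model = get_peft_model(self.model, lora_config)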