import os

import torch
from torch import nn

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

from peft import prepare_model_for_kbit_training
from peft import LoraConfig
from peft import get_peft_model

class mixEmbed(nn.Module):
    """Mixes text token embeddings with precomputed audio embeddings.

    Text positions carry ordinary token ids (>= 0); audio positions are encoded
    as -(k + 1), where k indexes a row of `audio_embeddings`.
    """

    def __init__(self, lm_embed: nn.Embedding, audio_embeddings, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lm_embed = lm_embed
        self.audio_embeddings = audio_embeddings  # stored here so the raw LM code needs no modification

    def forward(self, input_ids):
        # Non-negative ids are text tokens; negative ids select rows of the audio embeddings.
        text_ids = torch.clamp(input_ids.clone(), 0).long()
        au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()
        text_embeds = self.lm_embed(text_ids)
        au_embeds = self.audio_embeddings[au_ids]
        with torch.no_grad():
            embed_mask = (input_ids >= 0)  # >= 0 so that token id 0 is treated as text, not audio
        mix_embeds = au_embeds.clone()
        mix_embeds[embed_mask] = text_embeds[embed_mask]
        return mix_embeds
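

# A minimal, hedged usage sketch of the mixEmbed id convention. The sizes and
# tensors below are made-up illustrations, not values used elsewhere in this file:
# audio positions are written as -(k + 1) so that row k of `audio_embeddings`
# is selected, while non-negative ids go through the ordinary LM embedding table.
def _example_mix_embed():
    vocab_size, embed_dim, num_audio = 100, 8, 4            # hypothetical sizes
    lm_embed = nn.Embedding(vocab_size, embed_dim)
    audio_embeddings = torch.randn(num_audio, embed_dim)    # stand-in for flattened encoder output
    mix = mixEmbed(lm_embed, audio_embeddings)
    # two audio slots (rows 0 and 1) followed by text tokens 5 and 9
    input_ids = torch.tensor([[-1, -2, 5, 9]])
    return mix(input_ids)                                   # shape (1, 4, embed_dim)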
 

class LMDecoder(nn.Module):
    def __init__(self,
                # num_patches=196,
                img_size=(80,512),
                patch_size:int=16,
                in_chans:int=3,
                embed_dim=1024, # encoder embed dim
                decoder_embed_dim=512,
                norm_cfg=dict(type='LN', eps=1e-6),
                # patch_resolution=14,
                decoder_type='gpt2',
                freeze_decoder=True,
                additional_layer:int=0,
                ):
        super().__init__()
        self.decoder_type = decoder_type
        self.load_lm()
        
        self.lm_embed = self.lm.get_input_embeddings()
        try:
            self.lm_pos_embed = self.lm.get_position_embeddings()
        except NotImplementedError:
            self.lm_pos_embed = None  # the model uses rotary position embeddings

        if hasattr(self.lm, 'embed_dim'):
            self.embed_dim = self.lm.embed_dim
        else:
            self.embed_dim = decoder_embed_dim

        # self.asLM = asLM  # whether to emit tokens rather than hidden states
        # if self.asLM:  # TODO: why was this added in the first place?
        #     self.lm.set_output_embeddings(nn.Linear(self.embed_dim, self.LMconfig.vocab_size, bias=False))
        self.freeze_decoder = freeze_decoder
        if self.freeze_decoder:
            for para in self.lm.parameters():
                para.requires_grad = False
        
    def load_lm(self):
        ## ---------------------LM setting----------------------
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Read the Hugging Face access token from the environment rather than hard-coding it.
        hf_token = os.environ.get('HF_TOKEN')
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, token=hf_token)
        self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type, token=hf_token)
       
        
    def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)  # temporarily swap in the mixed text/audio embedding layer
        output = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels, output_hidden_states=True, **kwargs)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding layer
        return output

    def generate(self, input_ids, flatten_embs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)  # temporarily swap in the mixed text/audio embedding layer
        outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
        # Alternative sampling configuration, kept for reference:
        # outputs = self.lm.generate(input_ids=input_ids,
        #                            max_new_tokens=1024,
        #                            do_sample=True,
        #                            temperature=1.5,
        #                            num_beams=1,
        #                            top_p=0.9,
        #                            top_k=3,
        #                            use_cache=False)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding layer
        return outputs
'''
## infer params
max_input_tokens: 40
batch_size_test: 16
max_new_tokens: 64
min_length: 2
num_beams: 5
length_penalty: -2.0
top_p: 0.9
top_k: 3
no_repeat_ngram_size: 2
apply_lemmatizer: False
use_nucleus_sampling: True
'''
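

# A hedged usage sketch for LMDecoder.generate(). It assumes the public 'gpt2'
# checkpoint is reachable and invents a tiny prompt plus four stand-in audio
# embeddings; the prompt text and tensor sizes are illustrative only.
def _example_lm_decoder_generate():
    decoder = LMDecoder(decoder_type='gpt2')
    embed_dim = decoder.lm_embed.embedding_dim
    flatten_embs = torch.randn(4, embed_dim)                # pretend audio-encoder output
    prompt_ids = decoder.tokenizer("Describe the sound:", return_tensors='pt').input_ids
    audio_ids = torch.tensor([[-1, -2, -3, -4]])            # four audio slots, -(k + 1) encoding
    input_ids = torch.cat([audio_ids, prompt_ids], dim=1)
    return decoder.generate(input_ids, flatten_embs)
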

class LMDecoder_qlora(LMDecoder):
    def __init__(self,
                # num_patches=196,
                img_size=(80,512),
                patch_size:int=16,
                in_chans:int=3,
                embed_dim=1024, # encoder embed dim
                decoder_embed_dim=512,
                norm_cfg=dict(type='LN', eps=1e-6),
                # patch_resolution=14,
                decoder_type='gpt2',
                freeze_decoder=True,
                additional_layer:int=0,
                ):
        super().__init__(img_size, patch_size, in_chans, embed_dim, decoder_embed_dim, norm_cfg, decoder_type, freeze_decoder, additional_layer)
        
    def load_lm(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True)
        # 4-bit double-quantization config; it only takes effect if passed to from_pretrained below.
        double_quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(self.decoder_type,
                                                     # The kwargs below enable true 4-bit (QLoRA) loading and were noted
                                                     # as required before LoRA could be attached; they are currently
                                                     # disabled, so the model is loaded in full precision.
                                                     # device_map='auto',
                                                     # load_in_4bit=True,
                                                     # torch_dtype=torch.bfloat16,
                                                     # quantization_config=double_quant_config,
                                                     trust_remote_code=True)

        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["query_key_value"],  # matches GPT-NeoX/Falcon-style attention; use e.g. "c_attn" for GPT-2
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )

        self.lm = get_peft_model(model, lora_config)
        self.lm.print_trainable_parameters()
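

# A hedged construction sketch for the QLoRA variant. The checkpoint name is
# illustrative: any causal LM whose attention projection is called
# "query_key_value" (e.g. a GPT-NeoX-style model) matches the LoraConfig above.
# freeze_decoder=False keeps the LoRA adapters trainable after construction.
def _example_qlora_decoder():
    decoder = LMDecoder_qlora(decoder_type='EleutherAI/pythia-70m',
                              freeze_decoder=False)
    decoder.lm.print_trainable_parameters()  # only the LoRA adapter weights should remain trainable
    return decoder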