from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
from linformer import LinformerSelfAttention  # lucidrains' linformer package exports LinformerSelfAttention at the top level
import torch
import math 
from peft import get_peft_model, LoraConfig, TaskType
import os


# Freeze every parameter of a model so it is excluded from training
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False


# BERT_Compressor: compresses a tokenized passage into a short sequence of "memory" embeddings.
# It is unaffected by the Linformer change, which only touches the decoder.
class BERT_Compressor(torch.nn.Module):
    def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
        super().__init__()
        self.model_name = compr_model_name
        self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True) 
        self.compr_rate = compr_rate
        self.compressing_mode = compr_linear_type

        if self.compressing_mode == 'concat':
            # concat: flatten compr_rate consecutive hidden states and project them to one decoder embedding
            self.linear = torch.nn.Linear(self.model.config.hidden_size * self.compr_rate, decoder_hidden_size)
        elif self.compressing_mode == 'mean':
            # mean: project each hidden state, then average within the segment
            self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
        # nn.Module has no .float16() method; cast the projection to fp16 to match the encoder
        self.linear = self.linear.to(torch.float16)

    def forward(self, input_ids, attention_mask):
        segment_compress_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) 
        num_embs = math.ceil(input_ids.size(1) / self.compr_rate)
        all_hidden_states_emb = list()
        if self.compressing_mode == 'concat':
            for segment_idx in range(num_embs):
                start_idx = segment_idx * self.compr_rate
                end_idx = (segment_idx + 1) * self.compr_rate
                hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
                hidden_state_concat = torch.flatten(hidden_state, start_dim=1) #batch_size, hidden_state_dim * compression_rate
                all_hidden_states_emb.append(hidden_state_concat)
        elif self.compressing_mode == "mean":
            for segment_idx in range(num_embs):
                start_idx = segment_idx * self.compr_rate
                end_idx = (segment_idx + 1) * self.compr_rate
                hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
                all_hidden_states_emb.append(hidden_state)
        all_hidden_states_emb_cat = torch.stack(all_hidden_states_emb, dim=1)
        transformed_embeds = self.linear(all_hidden_states_emb_cat)
        
        if self.compressing_mode == "mean":
            transformed_embeds = torch.mean(transformed_embeds, dim=2)

        return  transformed_embeds

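# Worked example of the compressor's shapes (assumed numbers, derived from the code
# above): with compr_rate=64 and a context of 128 tokens (the stacking in forward()
# implicitly requires the input length to be padded to a multiple of compr_rate),
# num_embs = ceil(128 / 64) = 2, so the output has shape
# (batch_size, 2, decoder_hidden_size): two "memory" embeddings stand in for the
# 128 original context tokens.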

# COCOMConfig: model configuration; Linformer is the default attention implementation for the decoder
class COCOMConfig(PretrainedConfig):
    model_type = "COCOM"
    def __init__(self,
                 decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
                 quantization='no',
                 generation_top_k=1,
                 sep=False,
                 compr_model_name="bert-base-uncased",
                 compr_rate=64,
                 compr_linear_type='concat',
                 lora=False,
                 training_form="both",
                 lora_r=16,
                 attn_implementation="linformer",  # default changed to Linformer
                 device_map="cuda",
                 **kwargs):
        super().__init__(**kwargs)
        self.decoder_model_name = decoder_model_name 
        self.quantization = quantization 
        self.generation_top_k = generation_top_k 
        self.sep = sep 
        self.compr_model_name = compr_model_name 
        self.compr_rate = compr_rate 
        self.compr_linear_type = compr_linear_type 
        self.lora = lora 
        self.training_form = training_form 
        self.lora_r = lora_r 
        self.attn_implementation = attn_implementation
        self.device_map = device_map


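# LinformerAttentionAdapter is a hypothetical helper (not part of the linformer or
# transformers packages): a minimal sketch that wraps lucidrains' LinformerSelfAttention
# so it can replace the self_attn module of a Llama decoder layer. Assumptions: Linformer
# attention here is bidirectional and ignores the causal mask, rotary position ids and
# KV cache (so cached generation will not behave like the original attention), and the
# (output, weights, past) return matches what LlamaDecoderLayer unpacks in the
# transformers version assumed here.
class LinformerAttentionAdapter(torch.nn.Module):
    def __init__(self, hidden_size, num_heads, max_seq_len, k=256, dropout=0.1):
        super().__init__()
        self.attn = LinformerSelfAttention(
            dim=hidden_size,
            seq_len=max_seq_len,          # longest sequence the low-rank projections support
            k=k,                          # Linformer projection length
            heads=num_heads,
            dim_head=hidden_size // num_heads,
            dropout=dropout,
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        # LinformerSelfAttention only consumes the hidden states; the extra arguments
        # passed by LlamaDecoderLayer are accepted and ignored.
        attn_output = self.attn(hidden_states)
        return attn_output, None, past_key_value
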
# COCOM: BERT-based compressor + causal LM decoder; optionally swaps the decoder's self-attention for Linformer
class COCOM(PreTrainedModel):
    config_class = COCOMConfig
    def __init__(self, cfg):
        super().__init__(cfg)
        attn_impl = cfg.attn_implementation

        # Load the decoder LLM in fp16 (note: cfg.quantization is not applied in this file)
        self.decoder = AutoModelForCausalLM.from_pretrained(
            cfg.decoder_model_name, 
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map=cfg.device_map
        )
        
        # Replace decoder's attention mechanism with LinformerSelfAttention if configured
        if attn_impl == 'linformer':
            self._replace_attention_with_linformer()

        # Initialize other parts of the model (compression, LoRA, etc.)
        self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)
        if cfg.lora:
            self._apply_lora(cfg.lora_r)

        self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
        self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
        # Grow the decoder's embedding table so the newly added special tokens map to valid rows
        self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
        
    def _replace_attention_with_linformer(self):
        # Swap each decoder self-attention block for Linformer attention via the
        # LinformerAttentionAdapter sketched above. Assumes a Llama-style decoder,
        # which exposes its layers as decoder.model.layers with a .self_attn submodule
        # (the previous .transformer.h / .attn path is GPT-2 style and does not exist here).
        config = self.decoder.config
        for layer in self.decoder.model.layers:
            device = layer.self_attn.q_proj.weight.device
            layer.self_attn = LinformerAttentionAdapter(
                hidden_size=config.hidden_size,
                num_heads=config.num_attention_heads,
                max_seq_len=config.max_position_embeddings,
                dropout=0.1,
            ).to(device=device, dtype=torch.float16)

    def _apply_lora(self, lora_r):
        # Wrap the decoder with LoRA adapters on all linear layers
        peft_config = LoraConfig(
            task_type="CAUSAL_LM",
            r=lora_r,
            lora_alpha=2 * lora_r,
            target_modules='all-linear',
            lora_dropout=0.1,
        )
        self.decoder = get_peft_model(self.decoder, peft_config)

    def forward(self, enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask, labels):
        inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
        decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
        return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}

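    # NOTE: compress_and_replace_emb is called by forward() and generate() but is not
    # defined in this file. The method below is a minimal sketch of the assumed behaviour
    # (compress the retrieved context with BERT_Compressor and splice the resulting
    # embeddings into the decoder input wherever the <MEM> placeholder appears); it is
    # not the original implementation and assumes one compressed passage per prompt
    # (generation_top_k = 1) with exactly one <MEM> token per memory embedding.
    def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
        # (batch, num_mem, decoder_hidden_size) memory embeddings from the compressor
        compressed_embs = self.compr(enc_input_ids, enc_attention_mask)
        # Embed the decoder prompt, which contains the <MEM> placeholders
        inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids).clone()
        mem_id = self.decoder_tokenizer.convert_tokens_to_ids('<MEM>')
        for b in range(dec_input_ids.size(0)):
            mem_positions = (dec_input_ids[b] == mem_id).nonzero(as_tuple=True)[0]
            # Overwrite the <MEM> slots with the compressed context embeddings
            inputs_embeds[b, mem_positions] = compressed_embs[b, :mem_positions.size(0)].to(
                inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds
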
    def generate(self, model_input, max_new_tokens=128):
        device = self.decoder.device
        enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
        inputs_embeds = self.compress_and_replace_emb(enc_input_ids.to(device), enc_attention_mask.to(device), dec_input_ids.to(device))
        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds.to(device), 
            attention_mask=dec_attention_mask.to(device),
            do_sample=False,
            top_p=None,
            max_new_tokens=min(max_new_tokens, 4096)
        )
        decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded
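

# ---------------------------------------------------------------------------
# Minimal usage sketch (not from the original project). Assumptions: the default
# model names from COCOMConfig are available, a single CUDA device is present,
# and the prompt contains one <MEM> placeholder per memory embedding, i.e.
# ceil(context_tokens / compr_rate) of them.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = COCOMConfig(
        decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
        compr_model_name="bert-base-uncased",
        compr_rate=64,
        attn_implementation="linformer",
    )
    model = COCOM(cfg)
    model.compr = model.compr.to("cuda")  # put the BERT compressor on the same device as the decoder

    # Context to compress: pad/truncate to a multiple of compr_rate (here 2 * 64 = 128 tokens)
    context = "Paris is the capital and largest city of France. " * 8
    enc = model.compr.tokenizer(context, return_tensors="pt", padding="max_length",
                                truncation=True, max_length=128)
    n_mem = enc["input_ids"].size(1) // cfg.compr_rate  # -> 2 memory embeddings

    # Decoder prompt: one <MEM> slot per memory embedding, followed by the question
    prompt = "<MEM>" * n_mem + " Question: What is the capital of France? Answer:"
    dec = model.decoder_tokenizer(prompt, return_tensors="pt")

    answer = model.generate({
        "enc_input_ids": enc["input_ids"],
        "enc_attention_mask": enc["attention_mask"],
        "dec_input_ids": dec["input_ids"],
        "dec_attention_mask": dec["attention_mask"],
    })
    print(answer)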