from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig
from linformer import LinformerSelfAttention  # assumes lucidrains' `linformer` package, which exports this class at the top level
import torch
import math
from peft import get_peft_model, LoraConfig, TaskType


# Freeze model function (unchanged)
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False


# BERT_Compressor remains the same, since you are not modifying it for Linformer
class BERT_Compressor(torch.nn.Module):
    def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
        super().__init__()
        self.model_name = compr_model_name
        self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
        self.compr_rate = compr_rate
        self.compressing_mode = compr_linear_type

        if self.compressing_mode == 'concat':
            self.linear = torch.nn.Linear(self.model.config.hidden_size * self.compr_rate, decoder_hidden_size)
        elif self.compressing_mode == 'mean':
            self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
        self.linear = self.linear.to(torch.float16)  # nn.Module has no .float16(); use .to()/.half()

    def forward(self, input_ids, attention_mask):
        segment_compress_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        num_embs = math.ceil(input_ids.size(1) / self.compr_rate)
        all_hidden_states_emb = list()

        if self.compressing_mode == 'concat':
            for segment_idx in range(num_embs):
                start_idx = segment_idx * self.compr_rate
                end_idx = (segment_idx + 1) * self.compr_rate
                hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
                hidden_state_concat = torch.flatten(hidden_state, start_dim=1)  # batch_size, hidden_size * compr_rate
                all_hidden_states_emb.append(hidden_state_concat)
        elif self.compressing_mode == "mean":
            for segment_idx in range(num_embs):
                start_idx = segment_idx * self.compr_rate
                end_idx = (segment_idx + 1) * self.compr_rate
                hidden_state = segment_compress_outputs.hidden_states[-1][:, start_idx:end_idx, :]
                all_hidden_states_emb.append(hidden_state)

        all_hidden_states_emb_cat = torch.stack(all_hidden_states_emb, dim=1)
        transformed_embeds = self.linear(all_hidden_states_emb_cat)

        if self.compressing_mode == "mean":
            transformed_embeds = torch.mean(transformed_embeds, dim=2)

        return transformed_embeds


# Modify COCOMConfig to support Linformer
class COCOMConfig(PretrainedConfig):
    model_type = "COCOM"

    def __init__(self,
                 decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
                 quantization='no',
                 generation_top_k=1,
                 sep=False,
                 compr_model_name="bert-base-uncased",
                 compr_rate=64,
                 compr_linear_type='concat',
                 lora=False,
                 training_form="both",
                 lora_r=16,
                 attn_implementation="linformer",  # Change default to Linformer
                 device_map="cuda",
                 **kwargs):
        super().__init__(**kwargs)
        self.decoder_model_name = decoder_model_name
        self.quantization = quantization
        self.generation_top_k = generation_top_k
        self.sep = sep
        self.compr_model_name = compr_model_name
        self.compr_rate = compr_rate
        self.compr_linear_type = compr_linear_type
        self.lora = lora
        self.training_form = training_form
        self.lora_r = lora_r
        self.attn_implementation = attn_implementation
        self.device_map = device_map
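
# Hypothetical adapter, not part of the original snippet. LinformerSelfAttention
# (lucidrains' `linformer` package) takes only the hidden states and needs a fixed
# `seq_len`, while a Llama decoder layer calls its attention module with
# attention_mask / position_ids / cache arguments and expects a tuple back.
# The wrapper below is a minimal sketch of how the two interfaces could be bridged;
# the `seq_len` and `k` defaults and the shape of the returned tuple are assumptions
# that must be checked against the installed transformers version.
class LinformerAttentionAdapter(torch.nn.Module):
    def __init__(self, hidden_size, num_heads, seq_len=4096, k=256, dropout=0.1):
        super().__init__()
        self.linformer = LinformerSelfAttention(
            dim=hidden_size,
            seq_len=seq_len,                     # maximum sequence length (assumption)
            k=k,                                 # Linformer low-rank projection dim (assumption)
            heads=num_heads,
            dim_head=hidden_size // num_heads,
            dropout=dropout,
        )

    def forward(self, hidden_states, *args, **kwargs):
        # Linformer projects keys/values to a fixed length, so the causal mask and
        # KV cache handled by the original Llama attention are ignored here.
        attn_output = self.linformer(hidden_states)
        # Recent transformers releases unpack (attn_output, attn_weights) from the
        # attention module; older ones also expect a third element for the KV cache.
        return attn_output, None
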
# Modify COCOM model to use Linformer in the attention layer
class COCOM(PreTrainedModel):
    config_class = COCOMConfig

    def __init__(self, cfg):
        super().__init__(cfg)
        attn_impl = cfg.attn_implementation

        # Load the decoder; its attention is optionally swapped to Linformer below
        self.decoder = AutoModelForCausalLM.from_pretrained(
            cfg.decoder_model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map=cfg.device_map
        )

        # Replace the decoder's attention mechanism with Linformer attention if configured
        if attn_impl == 'linformer':
            self._replace_attention_with_linformer()

        # Initialize other parts of the model (compression, LoRA, etc.)
        self.compr = BERT_Compressor(cfg.compr_model_name, cfg.compr_rate, cfg.compr_linear_type, self.decoder.config.hidden_size)

        if cfg.lora:
            self._apply_lora(cfg.lora_r)

        self.decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
        self.decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['', '', '', '']})

    def _replace_attention_with_linformer(self):
        # Swap every self-attention module for the Linformer adapter defined above.
        # Llama-style decoders expose their layers as `model.layers` and each layer's
        # attention module as `self_attn`; the dimensions come from the decoder config
        # rather than from GPT-2 style `in_proj_weight` attributes.
        hidden_size = self.decoder.config.hidden_size
        num_heads = self.decoder.config.num_attention_heads
        for layer in self.decoder.model.layers:
            ref = layer.self_attn.q_proj.weight  # reuse the original dtype/device
            layer.self_attn = LinformerAttentionAdapter(
                hidden_size=hidden_size,
                num_heads=num_heads,
                seq_len=self.decoder.config.max_position_embeddings,
                dropout=0.1,
            ).to(device=ref.device, dtype=ref.dtype)

    def _apply_lora(self, lora_r):
        # Apply LoRA as per your configuration
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_r,
            lora_alpha=2 * lora_r,
            target_modules='all-linear',
            lora_dropout=0.1,
        )
        self.decoder = get_peft_model(self.decoder, peft_config)

    def forward(self, enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask, labels):
        inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
        decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
        return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}

    def generate(self, model_input, max_new_tokens=128):
        device = self.decoder.device
        enc_input_ids = model_input['enc_input_ids']
        enc_attention_mask = model_input['enc_attention_mask']
        dec_input_ids = model_input['dec_input_ids']
        dec_attention_mask = model_input['dec_attention_mask']
        inputs_embeds = self.compress_and_replace_emb(enc_input_ids.to(device), enc_attention_mask.to(device), dec_input_ids.to(device))
        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds.to(device),
            attention_mask=dec_attention_mask.to(device),
            do_sample=False,
            top_p=None,
            max_new_tokens=min(max_new_tokens, 4096)
        )
        decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded
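
    # Hypothetical sketch, not in the original snippet. forward() and generate() call
    # compress_and_replace_emb(), which comes from the original COCOM code and is
    # missing here. The version below only illustrates the idea: compress the retrieved
    # passages with BERT_Compressor and write the resulting embeddings into the decoder
    # input at the positions of placeholder "memory" tokens. `self.mem_token_id` is a
    # hypothetical attribute; the real implementation and placeholder tokens may differ.
    def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
        compressed = self.compr(enc_input_ids, enc_attention_mask)    # (batch * top_k, num_embs, hidden)
        dec_embeds = self.decoder.get_input_embeddings()(dec_input_ids).clone()
        mem_positions = dec_input_ids == self.mem_token_id            # placeholder slots in the prompt
        dec_embeds[mem_positions] = compressed.reshape(-1, compressed.size(-1)).to(dec_embeds.dtype)
        return dec_embeds


# Minimal construction sketch, assuming access to the gated Llama-2 checkpoint and a
# CUDA device; the compr_rate and LoRA values are illustrative, not tuned. Actual
# generation additionally requires decoder prompts containing the placeholder tokens
# expected by compress_and_replace_emb.
if __name__ == "__main__":
    cfg = COCOMConfig(
        decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
        compr_model_name="bert-base-uncased",
        compr_rate=64,
        attn_implementation="linformer",
        lora=True,
        lora_r=16,
    )
    model = COCOM(cfg)
    print(model.decoder.config.hidden_size, model.compr.model.config.hidden_size)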