Tu2003716 Plasmarine commited on
Commit
d6b992d
1 Parent(s): a57bc51

Longformer attn config (#4)

Browse files

- Longformer attn config (929ceeb15a1635aa4858e9e2e20a6860ee4eb32d)


Co-authored-by: Plasmarine <[email protected]>

Files changed (1) hide show
  1. modeling_cocom.py +48 -22
modeling_cocom.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel,LongformerForCausalLM, LongformerTokenizer
2
  import torch
3
  import math
4
  from peft import get_peft_model, LoraConfig, TaskType
@@ -71,7 +71,8 @@ class COCOMConfig(PretrainedConfig):
71
  lora = False,
72
  training_form="both",
73
  lora_r=16,
74
- attn_implementation="eager",
 
75
  device_map = "cuda",
76
  **kwargs):
77
  super().__init__(**kwargs)
@@ -95,6 +96,28 @@ class COCOM(PreTrainedModel):
95
  super().__init__(cfg)
96
  # define models
97
  attn_impl = cfg.attn_implementation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # model could be loaded in three quantization modes: no, int4, int8
99
  if cfg.quantization == "no":
100
  self.decoder = AutoModelForCausalLM.from_pretrained(
@@ -193,15 +216,20 @@ class COCOM(PreTrainedModel):
193
  self.compr_rate = cfg.compr_rate
194
  self.local_rank = os.getenv('LOCAL_RANK', '0')
195
 
196
- def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
197
  indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
 
 
198
  if self.compr:
199
  compressed_embs = self.compr(enc_input_ids, enc_attention_mask)
200
- input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
201
  else:
202
  compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
203
- input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
 
 
 
204
  return input_embeds
 
205
 
206
  def compr_decoder(self, input_ids, attention_mask):
207
  emb = self.decoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
@@ -212,19 +240,23 @@ class COCOM(PreTrainedModel):
212
  def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
213
  # Embed the decoder input
214
  inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
 
 
215
  num_embs = compressed_embs.size(1)
216
- if self.sep:
217
- slot_len = num_embs + 1
218
- else:
219
- slot_len = num_embs
220
- # get first mem_token inidices
221
  first_mem_token_indices = torch.argmax((dec_input_ids == self.decoder_tokenizer.mem_token_id).int(), dim=1)
222
  batch_size = inputs_embeds.size(0)
223
- # for each example in batch, replace them with compressed embeddings
 
224
  for i in range(batch_size):
225
  for j in range(indices[i], indices[i + 1]):
226
- start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
227
  inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
 
228
  return inputs_embeds
229
 
230
 
@@ -235,19 +267,13 @@ class COCOM(PreTrainedModel):
235
  dec_attention_mask: torch.LongTensor = None,
236
  labels: torch.LongTensor = None):
237
 
238
- # enc_input_ids: stores the contexts, should be flattened from all queries before input, dimention (batch_size*generation_top_k, token_length)
239
- # enc_attention_mask: attention mask of enc_input_ids
240
- # dec_input_ids: stores the prompts (including mem tokens), dimention (batch_size, token_length)
241
- # dec_attention_mask: attention mask of dec_input_ids
242
-
243
- # Perform compression with gradient tracking
244
- inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
245
 
246
- # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
247
  if (self.training_form == "compressor") and (self.compr is None):
248
- inputs_embeds = inputs_embeds.detach()
249
 
250
- # decoding
251
  decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
252
 
253
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel,LongformerForCausalLM, LongformerTokenizer, LongformerConfig
2
  import torch
3
  import math
4
  from peft import get_peft_model, LoraConfig, TaskType
 
71
  lora = False,
72
  training_form="both",
73
  lora_r=16,
74
+ attn_implementation="longformer",
75
+ attention_window=512,
76
  device_map = "cuda",
77
  **kwargs):
78
  super().__init__(**kwargs)
 
96
  super().__init__(cfg)
97
  # define models
98
  attn_impl = cfg.attn_implementation
99
+
100
+ if cfg.attn_implementation == "longformer":
101
+ # Initialize Longformer
102
+ longformer_config = LongformerConfig.from_pretrained(cfg.decoder_model_name)
103
+ longformer_config.attention_window = 512 # Modify based on context window size
104
+ self.decoder = LongformerForCausalLM.from_pretrained(
105
+ cfg.decoder_model_name,
106
+ config=longformer_config,
107
+ torch_dtype=torch.float16,
108
+ low_cpu_mem_usage=True,
109
+ device_map=cfg.device_map
110
+ )
111
+ else:
112
+ # Original decoder initialization
113
+ self.decoder = AutoModelForCausalLM.from_pretrained(
114
+ cfg.decoder_model_name,
115
+ torch_dtype=torch.float16,
116
+ attn_implementation=attn_impl,
117
+ low_cpu_mem_usage=True,
118
+ device_map=cfg.device_map
119
+ )
120
+
121
  # model could be loaded in three quantization modes: no, int4, int8
122
  if cfg.quantization == "no":
123
  self.decoder = AutoModelForCausalLM.from_pretrained(
 
216
  self.compr_rate = cfg.compr_rate
217
  self.local_rank = os.getenv('LOCAL_RANK', '0')
218
 
219
+ def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask):
220
  indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
221
+
222
+ # Perform compression
223
  if self.compr:
224
  compressed_embs = self.compr(enc_input_ids, enc_attention_mask)
 
225
  else:
226
  compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
227
+
228
+ # Replace embeddings with compressed ones
229
+ input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, dec_attention_mask, indices)
230
+
231
  return input_embeds
232
+
233
 
234
  def compr_decoder(self, input_ids, attention_mask):
235
  emb = self.decoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
 
240
  def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
241
  # Embed the decoder input
242
  inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
243
+
244
+ # Number of compressed embeddings
245
  num_embs = compressed_embs.size(1)
246
+
247
+ # Define slot length for memory tokens
248
+ slot_len = num_embs + 1 if self.sep else num_embs
249
+
250
+ # Find the first memory token indices
251
  first_mem_token_indices = torch.argmax((dec_input_ids == self.decoder_tokenizer.mem_token_id).int(), dim=1)
252
  batch_size = inputs_embeds.size(0)
253
+
254
+ # Replace memory tokens with compressed embeddings
255
  for i in range(batch_size):
256
  for j in range(indices[i], indices[i + 1]):
257
+ start_idx = first_mem_token_indices[i].item() + (j - indices[i]) * slot_len
258
  inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
259
+
260
  return inputs_embeds
261
 
262
 
 
267
  dec_attention_mask: torch.LongTensor = None,
268
  labels: torch.LongTensor = None):
269
 
270
+ inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask)
 
 
 
 
 
 
271
 
272
+ # Detach inputs_embeds if training compressor only
273
  if (self.training_form == "compressor") and (self.compr is None):
274
+ inputs_embeds = inputs_embeds.detach()
275
 
276
+ # Pass through the decoder
277
  decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
278
 
279
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}