Plasmarine committed on
Commit
fc42df6
1 Parent(s): d6b992d
Files changed (1)
  1. modeling_cocom.py +23 -49
modeling_cocom.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel,LongformerForCausalLM, LongformerTokenizer, LongformerConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
 import torch
 import math
 from peft import get_peft_model, LoraConfig, TaskType
@@ -71,8 +71,7 @@ class COCOMConfig(PretrainedConfig):
                 lora = False,
                 training_form="both",
                 lora_r=16,
-                attn_implementation="longformer",
-                attention_window=512,
+                attn_implementation="eager",
                 device_map = "cuda",
                 **kwargs):
         super().__init__(**kwargs)
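As a quick illustration of the config change above, a hypothetical instantiation using only the keyword arguments visible in this hunk (the import assumes modeling_cocom.py is on the path; any other COCOMConfig arguments are omitted here):

from modeling_cocom import COCOMConfig

cfg = COCOMConfig(
    lora=False,
    training_form="both",
    lora_r=16,
    attn_implementation="eager",   # new default; the longformer/attention_window pair is removed
    device_map="cuda",
)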
@@ -96,28 +95,6 @@ class COCOM(PreTrainedModel):
         super().__init__(cfg)
         # define models
         attn_impl = cfg.attn_implementation
-
-        if cfg.attn_implementation == "longformer":
-            # Initialize Longformer
-            longformer_config = LongformerConfig.from_pretrained(cfg.decoder_model_name)
-            longformer_config.attention_window = 512  # Modify based on context window size
-            self.decoder = LongformerForCausalLM.from_pretrained(
-                cfg.decoder_model_name,
-                config=longformer_config,
-                torch_dtype=torch.float16,
-                low_cpu_mem_usage=True,
-                device_map=cfg.device_map
-            )
-        else:
-            # Original decoder initialization
-            self.decoder = AutoModelForCausalLM.from_pretrained(
-                cfg.decoder_model_name,
-                torch_dtype=torch.float16,
-                attn_implementation=attn_impl,
-                low_cpu_mem_usage=True,
-                device_map=cfg.device_map
-            )
-
         # model could be loaded in three quantization modes: no, int4, int8
         if cfg.quantization == "no":
             self.decoder = AutoModelForCausalLM.from_pretrained(
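For orientation, a minimal standalone sketch (not part of the commit; the checkpoint name is illustrative) of the single decoder-initialization path that remains once the Longformer branch above is removed:

from transformers import AutoModelForCausalLM
import torch

# Hypothetical equivalent of the retained branch; the model name is only a placeholder
decoder = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="eager",   # matches the new config default introduced in this commit
    low_cpu_mem_usage=True,
    device_map="cuda",
)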
@@ -216,20 +193,15 @@ class COCOM(PreTrainedModel):
         self.compr_rate = cfg.compr_rate
         self.local_rank = os.getenv('LOCAL_RANK', '0')
 
-    def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask):
+    def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
         indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
-
-        # Perform compression
         if self.compr:
             compressed_embs = self.compr(enc_input_ids, enc_attention_mask)
+            input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
         else:
             compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
-
-        # Replace embeddings with compressed ones
-        input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, dec_attention_mask, indices)
-
+            input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
         return input_embeds
-
 
     def compr_decoder(self, input_ids, attention_mask):
         emb = self.decoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
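A small self-contained sketch (illustrative sizes, not repository code) of the grouping that the rewritten compress_and_replace_emb relies on: the encoder batch is flattened to one row per (query, context) pair, and the indices range carves it into generation_top_k-sized groups, one group per decoder prompt:

import torch

batch_size, top_k, num_mem, hidden = 2, 3, 4, 8   # illustrative sizes

# one compressed representation per (query, retrieved context) pair
compressed_embs = torch.randn(batch_size * top_k, num_mem, hidden)

# same construction as range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
indices = range(0, compressed_embs.size(0) + 1, top_k)

for i in range(batch_size):
    group = compressed_embs[indices[i]:indices[i + 1]]   # contexts belonging to query i
    print(f"query {i}: rows {indices[i]}..{indices[i + 1] - 1}, group shape {tuple(group.shape)}")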
@@ -240,23 +212,19 @@ class COCOM(PreTrainedModel):
     def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
         # Embed the decoder input
         inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
-
-        # Number of compressed embeddings
         num_embs = compressed_embs.size(1)
-
-        # Define slot length for memory tokens
-        slot_len = num_embs + 1 if self.sep else num_embs
-
-        # Find the first memory token indices
+        if self.sep:
+            slot_len = num_embs + 1
+        else:
+            slot_len = num_embs
+        # get first mem_token inidices
         first_mem_token_indices = torch.argmax((dec_input_ids == self.decoder_tokenizer.mem_token_id).int(), dim=1)
         batch_size = inputs_embeds.size(0)
-
-        # Replace memory tokens with compressed embeddings
+        # for each example in batch, replace them with compressed embeddings
         for i in range(batch_size):
             for j in range(indices[i], indices[i + 1]):
-                start_idx = first_mem_token_indices[i].item() + (j - indices[i]) * slot_len
+                start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
                 inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
-
         return inputs_embeds
 
 
@@ -267,13 +235,19 @@ class COCOM(PreTrainedModel):
                 dec_attention_mask: torch.LongTensor = None,
                 labels: torch.LongTensor = None):
 
-        inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask)
+        # enc_input_ids: stores the contexts, should be flattened from all queries before input, dimention (batch_size*generation_top_k, token_length)
+        # enc_attention_mask: attention mask of enc_input_ids
+        # dec_input_ids: stores the prompts (including mem tokens), dimention (batch_size, token_length)
+        # dec_attention_mask: attention mask of dec_input_ids
+
+        # Perform compression with gradient tracking
+        inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
 
-        # Detach inputs_embeds if training compressor only
+        # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
         if (self.training_form == "compressor") and (self.compr is None):
-            inputs_embeds = inputs_embeds.detach()
+            inputs_embeds = inputs_embeds.detach()
 
-        # Pass through the decoder
+        # decoding
         decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)
 
         return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
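To illustrate the slot arithmetic used by replace_embeddings, here is a toy, self-contained sketch (token ids, sizes, and the separator layout are assumptions, not values from the repository): each run of num_embs mem-token positions, optionally followed by a separator, receives one compressed context:

import torch

mem_token_id, sep_token_id = 32000, 32001   # hypothetical special-token ids
num_mem, top_k, hidden = 4, 2, 8
sep = True
slot_len = num_mem + 1 if sep else num_mem

# one prompt: [bos] + (<mem> * 4 + <sep>) * 2 + question tokens
dec_input_ids = torch.tensor([[1] + ([mem_token_id] * num_mem + [sep_token_id]) * top_k + [5, 6, 7]])

first_mem = torch.argmax((dec_input_ids == mem_token_id).int(), dim=1)   # position of the first <mem>

inputs_embeds = torch.zeros(1, dec_input_ids.size(1), hidden)
compressed_embs = torch.randn(top_k, num_mem, hidden)

for j in range(top_k):
    start = first_mem.item() + j * slot_len
    inputs_embeds[0, start:start + num_mem] = compressed_embs[j]
    print(f"context {j} fills positions {start}..{start + num_mem - 1}")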
@@ -289,7 +263,7 @@ class COCOM(PreTrainedModel):
             attention_mask=dec_attention_mask.to(device),
             do_sample=False,
             top_p=None,
-            max_new_tokens=min(max_new_tokens, 4096)
+            max_new_tokens=max_new_tokens
         )
         decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         return decoded
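For reference, a hedged sketch of the input shapes the updated forward expects, following the comments added in this commit (sizes are illustrative and no model is actually instantiated here):

import torch

batch_size, top_k = 2, 3
enc_len, dec_len, vocab = 128, 64, 32000

# contexts from all queries, flattened to (batch_size * generation_top_k, token_length)
enc_input_ids = torch.randint(0, vocab, (batch_size * top_k, enc_len))
enc_attention_mask = torch.ones_like(enc_input_ids)

# one prompt per query (mem tokens included), shape (batch_size, token_length)
dec_input_ids = torch.randint(0, vocab, (batch_size, dec_len))
dec_attention_mask = torch.ones_like(dec_input_ids)

# with a loaded COCOM instance (hypothetical `model`), the training call would look like:
# out = model(enc_input_ids=enc_input_ids, enc_attention_mask=enc_attention_mask,
#             dec_input_ids=dec_input_ids, dec_attention_mask=dec_attention_mask,
#             labels=dec_input_ids)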