Tu2003716 committed on
Commit
0706b6b
1 Parent(s): 1010809

Upload modeling_cocom.py

Browse files
Files changed (1) hide show
  1. modeling_cocom.py +2 -11
modeling_cocom.py CHANGED
@@ -196,7 +196,6 @@ class COCOM(PreTrainedModel):
196
  else:
197
  compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
198
  input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
199
- inputs_embeds = inputs_embeds.to(compressed_embs.device)
200
  return input_embeds
201
 
202
  def compr_decoder(self, input_ids, attention_mask):
@@ -220,9 +219,7 @@ class COCOM(PreTrainedModel):
220
  for i in range(batch_size):
221
  for j in range(indices[i], indices[i + 1]):
222
  start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
223
- # inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
224
- inputs_embeds = inputs_embeds.to(compressed_embs.device)
225
- inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j].to(inputs_embeds.device)
226
  return inputs_embeds
227
 
228
 
@@ -239,13 +236,7 @@ class COCOM(PreTrainedModel):
239
  # dec_attention_mask: attention mask of dec_input_ids
240
 
241
  # Perform compression with gradient tracking
242
- # inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
243
- inputs_embeds = self.compress_and_replace_emb(
244
- enc_input_ids.to(self.decoder.device),
245
- enc_attention_mask.to(self.decoder.device),
246
- dec_input_ids.to(self.decoder.device),
247
- )
248
-
249
  # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
250
  if (self.training_form == "compressor") and (self.compr is None):
251
  inputs_embeds = inputs_embeds.detach()
 
196
  else:
197
  compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
198
  input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
 
199
  return input_embeds
200
 
201
  def compr_decoder(self, input_ids, attention_mask):
 
219
  for i in range(batch_size):
220
  for j in range(indices[i], indices[i + 1]):
221
  start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
222
+ inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
 
 
223
  return inputs_embeds
224
 
225
 
 
236
  # dec_attention_mask: attention mask of dec_input_ids
237
 
238
  # Perform compression with gradient tracking
239
+ inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
 
 
 
 
 
 
240
  # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
241
  if (self.training_form == "compressor") and (self.compr is None):
242
  inputs_embeds = inputs_embeds.detach()