Upload modeling_cocom.py

modeling_cocom.py (+2 -11)
@@ -196,7 +196,6 @@ class COCOM(PreTrainedModel):
         else:
             compressed_embs = self.compr_decoder(enc_input_ids, enc_attention_mask)
             input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
-            inputs_embeds = inputs_embeds.to(compressed_embs.device)
         return input_embeds
 
     def compr_decoder(self, input_ids, attention_mask):
@@ -220,9 +219,7 @@ class COCOM(PreTrainedModel):
         for i in range(batch_size):
             for j in range(indices[i], indices[i + 1]):
                 start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
-
-                inputs_embeds = inputs_embeds.to(compressed_embs.device)
-                inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j].to(inputs_embeds.device)
+                inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
         return inputs_embeds
 
 
@@ -239,13 +236,7 @@ class COCOM(PreTrainedModel):
         # dec_attention_mask: attention mask of dec_input_ids
 
         # Perform compression with gradient tracking
-
-        inputs_embeds = self.compress_and_replace_emb(
-            enc_input_ids.to(self.decoder.device),
-            enc_attention_mask.to(self.decoder.device),
-            dec_input_ids.to(self.decoder.device),
-        )
-
+        inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
         # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
         if (self.training_form == "compressor") and (self.compr is None):
             inputs_embeds = inputs_embeds.detach()