Tu2003716 committed on
Commit
00a78e8
1 Parent(s): 0706b6b

Upload modeling_cocom.py

Browse files
Files changed (1) hide show
  1. modeling_cocom.py +5 -1
modeling_cocom.py CHANGED
@@ -61,7 +61,7 @@ class COCOMConfig(PretrainedConfig):
61
 
62
  model_type = "COCOM"
63
  def __init__(self,
64
- decoder_model_name="google-t5/t5-base",
65
  quantization = 'no',
66
  generation_top_k = 1,
67
  sep = False,
@@ -100,6 +100,7 @@ class COCOM(PreTrainedModel):
100
  torch_dtype=torch.float16,
101
  attn_implementation=attn_impl,
102
  low_cpu_mem_usage = True,
 
103
  )
104
  elif cfg.quantization == "int4":
105
  quant_config = BitsAndBytesConfig(
@@ -116,6 +117,7 @@ class COCOM(PreTrainedModel):
116
  resume_download=True,
117
  low_cpu_mem_usage = True,
118
  trust_remote_code=True,
 
119
  )
120
  elif cfg.quantization == "int8":
121
  quant_config = BitsAndBytesConfig(
@@ -132,6 +134,7 @@ class COCOM(PreTrainedModel):
132
  resume_download=True,
133
  low_cpu_mem_usage = True,
134
  trust_remote_code=True,
 
135
  )
136
  else:
137
  raise NotImplementedError()
@@ -237,6 +240,7 @@ class COCOM(PreTrainedModel):
237
 
238
  # Perform compression with gradient tracking
239
  inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
 
240
  # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
241
  if (self.training_form == "compressor") and (self.compr is None):
242
  inputs_embeds = inputs_embeds.detach()
 
61
 
62
  model_type = "COCOM"
63
  def __init__(self,
64
+ decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
65
  quantization = 'no',
66
  generation_top_k = 1,
67
  sep = False,
 
100
  torch_dtype=torch.float16,
101
  attn_implementation=attn_impl,
102
  low_cpu_mem_usage = True,
103
+ device_map='auto'
104
  )
105
  elif cfg.quantization == "int4":
106
  quant_config = BitsAndBytesConfig(
 
117
  resume_download=True,
118
  low_cpu_mem_usage = True,
119
  trust_remote_code=True,
120
+ device_map='auto'
121
  )
122
  elif cfg.quantization == "int8":
123
  quant_config = BitsAndBytesConfig(
 
134
  resume_download=True,
135
  low_cpu_mem_usage = True,
136
  trust_remote_code=True,
137
+ device_map='auto'
138
  )
139
  else:
140
  raise NotImplementedError()
 
240
 
241
  # Perform compression with gradient tracking
242
  inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
243
+
244
  # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
245
  if (self.training_form == "compressor") and (self.compr is None):
246
  inputs_embeds = inputs_embeds.detach()