Upload modeling_cocom.py
Browse files- modeling_cocom.py +5 -1
modeling_cocom.py
CHANGED
@@ -61,7 +61,7 @@ class COCOMConfig(PretrainedConfig):
|
|
61 |
|
62 |
model_type = "COCOM"
|
63 |
def __init__(self,
|
64 |
-
decoder_model_name="
|
65 |
quantization = 'no',
|
66 |
generation_top_k = 1,
|
67 |
sep = False,
|
@@ -100,6 +100,7 @@ class COCOM(PreTrainedModel):
|
|
100 |
torch_dtype=torch.float16,
|
101 |
attn_implementation=attn_impl,
|
102 |
low_cpu_mem_usage = True,
|
|
|
103 |
)
|
104 |
elif cfg.quantization == "int4":
|
105 |
quant_config = BitsAndBytesConfig(
|
@@ -116,6 +117,7 @@ class COCOM(PreTrainedModel):
|
|
116 |
resume_download=True,
|
117 |
low_cpu_mem_usage = True,
|
118 |
trust_remote_code=True,
|
|
|
119 |
)
|
120 |
elif cfg.quantization == "int8":
|
121 |
quant_config = BitsAndBytesConfig(
|
@@ -132,6 +134,7 @@ class COCOM(PreTrainedModel):
|
|
132 |
resume_download=True,
|
133 |
low_cpu_mem_usage = True,
|
134 |
trust_remote_code=True,
|
|
|
135 |
)
|
136 |
else:
|
137 |
raise NotImplementedError()
|
@@ -237,6 +240,7 @@ class COCOM(PreTrainedModel):
|
|
237 |
|
238 |
# Perform compression with gradient tracking
|
239 |
inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
|
|
|
240 |
# If training_form is "compressor", detach inputs_embeds so that the gradient is not counted in the decoder.
|
241 |
if (self.training_form == "compressor") and (self.compr is None):
|
242 |
inputs_embeds = inputs_embeds.detach()
|
|
|
61 |
|
62 |
model_type = "COCOM"
|
63 |
def __init__(self,
|
64 |
+
decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
|
65 |
quantization = 'no',
|
66 |
generation_top_k = 1,
|
67 |
sep = False,
|
|
|
100 |
torch_dtype=torch.float16,
|
101 |
attn_implementation=attn_impl,
|
102 |
low_cpu_mem_usage = True,
|
103 |
+
device_map='auto'
|
104 |
)
|
105 |
elif cfg.quantization == "int4":
|
106 |
quant_config = BitsAndBytesConfig(
|
|
|
117 |
resume_download=True,
|
118 |
low_cpu_mem_usage = True,
|
119 |
trust_remote_code=True,
|
120 |
+
device_map='auto'
|
121 |
)
|
122 |
elif cfg.quantization == "int8":
|
123 |
quant_config = BitsAndBytesConfig(
|
|
|
134 |
resume_download=True,
|
135 |
low_cpu_mem_usage = True,
|
136 |
trust_remote_code=True,
|
137 |
+
device_map='auto'
|
138 |
)
|
139 |
else:
|
140 |
raise NotImplementedError()
|
|
|
240 |
|
241 |
# Perform compression with gradient tracking
|
242 |
inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
|
243 |
+
|
244 |
# If training_form is "compressor", detach inputs_embeds so that the gradient is not counted in the decoder.
|
245 |
if (self.training_form == "compressor") and (self.compr is None):
|
246 |
inputs_embeds = inputs_embeds.detach()
|