Commit: update device_map
Browse files — modeling_cocom.py (+7 −1)
modeling_cocom.py
CHANGED
@@ -61,7 +61,7 @@ class COCOMConfig(PretrainedConfig):
|
|
61 |
|
62 |
model_type = "COCOM"
|
63 |
def __init__(self,
|
64 |
-
decoder_model_name="
|
65 |
quantization = 'no',
|
66 |
generation_top_k = 1,
|
67 |
sep = False,
|
@@ -72,6 +72,7 @@ class COCOMConfig(PretrainedConfig):
|
|
72 |
training_form="both",
|
73 |
lora_r=16,
|
74 |
attn_implementation="eager",
|
|
|
75 |
**kwargs):
|
76 |
super().__init__(**kwargs)
|
77 |
|
@@ -86,6 +87,7 @@ class COCOMConfig(PretrainedConfig):
|
|
86 |
self.training_form = training_form # training form, could be compressor: training only compressor; both:
|
87 |
self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
|
88 |
self.attn_implementation = attn_implementation
|
|
|
89 |
|
90 |
class COCOM(PreTrainedModel):
|
91 |
config_class = COCOMConfig
|
@@ -100,6 +102,7 @@ class COCOM(PreTrainedModel):
|
|
100 |
torch_dtype=torch.float16,
|
101 |
attn_implementation=attn_impl,
|
102 |
low_cpu_mem_usage = True,
|
|
|
103 |
)
|
104 |
elif cfg.quantization == "int4":
|
105 |
quant_config = BitsAndBytesConfig(
|
@@ -116,6 +119,7 @@ class COCOM(PreTrainedModel):
|
|
116 |
resume_download=True,
|
117 |
low_cpu_mem_usage = True,
|
118 |
trust_remote_code=True,
|
|
|
119 |
)
|
120 |
elif cfg.quantization == "int8":
|
121 |
quant_config = BitsAndBytesConfig(
|
@@ -132,6 +136,7 @@ class COCOM(PreTrainedModel):
|
|
132 |
resume_download=True,
|
133 |
low_cpu_mem_usage = True,
|
134 |
trust_remote_code=True,
|
|
|
135 |
)
|
136 |
else:
|
137 |
raise NotImplementedError()
|
@@ -237,6 +242,7 @@ class COCOM(PreTrainedModel):
|
|
237 |
|
238 |
# Perform compression with gradient tracking
|
239 |
inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
|
|
|
240 |
# if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
|
241 |
if (self.training_form == "compressor") and (self.compr is None):
|
242 |
inputs_embeds = inputs_embeds.detach()
|
|
|
61 |
|
62 |
model_type = "COCOM"
|
63 |
def __init__(self,
|
64 |
+
decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
|
65 |
quantization = 'no',
|
66 |
generation_top_k = 1,
|
67 |
sep = False,
|
|
|
72 |
training_form="both",
|
73 |
lora_r=16,
|
74 |
attn_implementation="eager",
|
75 |
+
device_map = "cuda",
|
76 |
**kwargs):
|
77 |
super().__init__(**kwargs)
|
78 |
|
|
|
87 |
self.training_form = training_form # training form, could be compressor: training only compressor; both:
|
88 |
self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
|
89 |
self.attn_implementation = attn_implementation
|
90 |
+
self.device_map = device_map
|
91 |
|
92 |
class COCOM(PreTrainedModel):
|
93 |
config_class = COCOMConfig
|
|
|
102 |
torch_dtype=torch.float16,
|
103 |
attn_implementation=attn_impl,
|
104 |
low_cpu_mem_usage = True,
|
105 |
+
device_map =cfg.device_map
|
106 |
)
|
107 |
elif cfg.quantization == "int4":
|
108 |
quant_config = BitsAndBytesConfig(
|
|
|
119 |
resume_download=True,
|
120 |
low_cpu_mem_usage = True,
|
121 |
trust_remote_code=True,
|
122 |
+
device_map =cfg.device_map
|
123 |
)
|
124 |
elif cfg.quantization == "int8":
|
125 |
quant_config = BitsAndBytesConfig(
|
|
|
136 |
resume_download=True,
|
137 |
low_cpu_mem_usage = True,
|
138 |
trust_remote_code=True,
|
139 |
+
device_map =cfg.device_map
|
140 |
)
|
141 |
else:
|
142 |
raise NotImplementedError()
|
|
|
242 |
|
243 |
# Perform compression with gradient tracking
|
244 |
inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
|
245 |
+
|
246 |
# if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
|
247 |
if (self.training_form == "compressor") and (self.compr is None):
|
248 |
inputs_embeds = inputs_embeds.detach()
|