temp

modeling_cocom.py  CHANGED  (+9 -7)
@@ -72,6 +72,7 @@ class COCOMConfig(PretrainedConfig):
                  training_form="both",
                  lora_r=16,
                  attn_implementation="eager",
+                 device_map = "cuda",
                  **kwargs):
         super().__init__(**kwargs)
 
@@ -86,6 +87,7 @@ class COCOMConfig(PretrainedConfig):
         self.training_form = training_form # training form, could be compressor: training only comprssor; both:
         self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
         self.attn_implementation = attn_implementation
+        self.device_map = device_map
 
 class COCOM(PreTrainedModel):
     config_class = COCOMConfig
@@ -100,7 +102,7 @@ class COCOM(PreTrainedModel):
                 torch_dtype=torch.float16,
                 attn_implementation=attn_impl,
                 low_cpu_mem_usage = True,
-                device_map=
+                device_map =cfg.device_map
             )
         elif cfg.quantization == "int4":
             quant_config = BitsAndBytesConfig(
@@ -117,7 +119,7 @@ class COCOM(PreTrainedModel):
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
-                device_map=
+                device_map =cfg.device_map
             )
         elif cfg.quantization == "int8":
             quant_config = BitsAndBytesConfig(
@@ -134,7 +136,7 @@ class COCOM(PreTrainedModel):
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
-                device_map=
+                device_map =cfg.device_map
             )
         else:
             raise NotImplementedError()
@@ -300,10 +302,10 @@ class COCOM(PreTrainedModel):
 
         # generate
         model_input = {
-            'enc_input_ids': enc_input['input_ids'],
-            'enc_attention_mask': enc_input['attention_mask'],
-            'dec_input_ids': inp_dec['input_ids'],
-            'dec_attention_mask': inp_dec['attention_mask']
+            'enc_input_ids': enc_input['input_ids'].to(self.decoder.device),
+            'enc_attention_mask': enc_input['attention_mask'].to(self.decoder.device),
+            'dec_input_ids': inp_dec['input_ids'].to(self.decoder.device),
+            'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
         }
 
         return self.generate(model_input, max_new_tokens)
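The diff threads a `device_map` option from `COCOMConfig` (defaulting to `"cuda"`) into each decoder `from_pretrained` call and moves the generation inputs onto the decoder's device. A minimal standalone sketch of that same pattern with plain Transformers, not the COCOM class itself; the checkpoint name and prompt below are placeholders, not taken from this repository:

# Sketch of the device-placement pattern used in this change:
# pass `device_map` through to from_pretrained, then move the tokenized
# inputs onto the model's device before generation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"   # placeholder decoder checkpoint
device_map = "cuda"                        # same default the new config field uses

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("What is context compression?", return_tensors="pt")
# Mirror the generate() change: input tensors must live on the model's device.
inputs = {k: v.to(model.device) for k, v in inputs.items()}

output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

With the config field in place, the equivalent for this repository would presumably be constructing `COCOMConfig(device_map="cuda")` (other constructor arguments keeping their defaults) and letting the model pass `cfg.device_map` to its decoder loading code, as the hunks above show.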