Shiro2 committed
Commit 6330de6 · verified · 1 Parent(s): 0706b6b

update device_map

Files changed (1): modeling_cocom.py (+7 -1)
modeling_cocom.py CHANGED
```diff
@@ -61,7 +61,7 @@ class COCOMConfig(PretrainedConfig):
 
     model_type = "COCOM"
     def __init__(self,
-                decoder_model_name="google-t5/t5-base",
+                decoder_model_name="meta-llama/Llama-2-7b-chat-hf",
                 quantization = 'no',
                 generation_top_k = 1,
                 sep = False,
@@ -72,6 +72,7 @@ class COCOMConfig(PretrainedConfig):
                 training_form="both",
                 lora_r=16,
                 attn_implementation="eager",
+                device_map = "cuda",
                 **kwargs):
         super().__init__(**kwargs)
 
@@ -86,6 +87,7 @@ class COCOMConfig(PretrainedConfig):
         self.training_form = training_form # training form, could be compressor: training only comprssor; both:
         self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
         self.attn_implementation = attn_implementation
+        self.device_map = device_map
 
 class COCOM(PreTrainedModel):
     config_class = COCOMConfig
@@ -100,6 +102,7 @@ class COCOM(PreTrainedModel):
                 torch_dtype=torch.float16,
                 attn_implementation=attn_impl,
                 low_cpu_mem_usage = True,
+                device_map = cfg.device_map
             )
         elif cfg.quantization == "int4":
             quant_config = BitsAndBytesConfig(
@@ -116,6 +119,7 @@ class COCOM(PreTrainedModel):
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
+                device_map = cfg.device_map
             )
         elif cfg.quantization == "int8":
             quant_config = BitsAndBytesConfig(
@@ -132,6 +136,7 @@ class COCOM(PreTrainedModel):
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
+                device_map = cfg.device_map
             )
         else:
             raise NotImplementedError()
@@ -237,6 +242,7 @@ class COCOM(PreTrainedModel):
 
         # Perform compression with gradient tracking
         inputs_embeds = self.compress_and_replace_emb(enc_input_ids, enc_attention_mask, dec_input_ids)
+
         # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
         if (self.training_form == "compressor") and (self.compr is None):
             inputs_embeds = inputs_embeds.detach()
```
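For reference, each quantization branch now ends with the same loading pattern: `device_map` is forwarded to `from_pretrained` alongside the quantization config, so the (possibly quantized) weights are placed directly on the target device. A minimal sketch of that pattern using standard `transformers` APIs; the checkpoint name and the exact `BitsAndBytesConfig` fields are illustrative assumptions, not taken from the diff:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative 4-bit config; the exact fields modeling_cocom.py uses
# are outside the hunks shown in this diff.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # illustrative checkpoint
    quantization_config=quant_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="cuda",  # the kwarg this commit threads through from the config
)
```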
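And a hedged sketch of how the new option would be set from user code. The `COCOMConfig`/`COCOM` names and their parameters come from the file above; the import path, the constructor call, and the argument values are assumptions:

```python
from modeling_cocom import COCOM, COCOMConfig  # assumed import path

cfg = COCOMConfig(
    decoder_model_name="meta-llama/Llama-2-7b-chat-hf",  # new default in this commit
    quantization="int4",  # selects the BitsAndBytesConfig branch shown above
    device_map="cuda",    # new in this commit; e.g. "auto" would let accelerate shard
)
model = COCOM(cfg)  # assumed construction, per the PreTrainedModel convention
```

Routing `device_map` through the config instead of hard-coding a device means the same checkpoint can be loaded on a different host (CPU-only, multi-GPU) by overriding a single field.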