Files changed (1)
  1. modeling_cocom.py +9 -7
modeling_cocom.py CHANGED
@@ -72,6 +72,7 @@ class COCOMConfig(PretrainedConfig):
                 training_form="both",
                 lora_r=16,
                 attn_implementation="eager",
+                device_map="cuda",
                 **kwargs):
         super().__init__(**kwargs)
 
@@ -86,6 +87,7 @@ class COCOMConfig(PretrainedConfig):
         self.training_form = training_form # training form, could be compressor: training only comprssor; both:
         self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
         self.attn_implementation = attn_implementation
+        self.device_map = device_map
 
 class COCOM(PreTrainedModel):
     config_class = COCOMConfig
@@ -100,7 +102,7 @@ class COCOM(PreTrainedModel):
                 torch_dtype=torch.float16,
                 attn_implementation=attn_impl,
                 low_cpu_mem_usage = True,
-                device_map='auto'
+                device_map=cfg.device_map
             )
         elif cfg.quantization == "int4":
             quant_config = BitsAndBytesConfig(
@@ -117,7 +119,7 @@ class COCOM(PreTrainedModel):
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
-                device_map='auto'
+                device_map=cfg.device_map
             )
         elif cfg.quantization == "int8":
             quant_config = BitsAndBytesConfig(
@@ -134,7 +136,7 @@ class COCOM(PreTrainedModel):
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
-                device_map='auto'
+                device_map=cfg.device_map
             )
         else:
             raise NotImplementedError()
@@ -300,10 +302,10 @@ class COCOM(PreTrainedModel):
 
         # generate
         model_input = {
-            'enc_input_ids': enc_input['input_ids'],
-            'enc_attention_mask': enc_input['attention_mask'],
-            'dec_input_ids': inp_dec['input_ids'],
-            'dec_attention_mask': inp_dec['attention_mask']
+            'enc_input_ids': enc_input['input_ids'].to(self.decoder.device),
+            'enc_attention_mask': enc_input['attention_mask'].to(self.decoder.device),
+            'dec_input_ids': inp_dec['input_ids'].to(self.decoder.device),
+            'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
         }
 
         return self.generate(model_input, max_new_tokens)
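
For reference, a minimal usage sketch of the new option, assuming the model is built directly from COCOMConfig (the repository may instead expose a COCOM.from_pretrained entry point). Only device_map is introduced by this change; the decoder_model_name keyword, the checkpoint id, and the construction path shown here are assumptions for illustration.

    from modeling_cocom import COCOM, COCOMConfig

    # Pin all sub-models to one device instead of letting accelerate
    # shard them with device_map='auto' (the previously hard-coded value).
    cfg = COCOMConfig(
        decoder_model_name="meta-llama/Llama-2-7b-chat-hf",  # assumption: any causal-LM checkpoint
        quantization="no",
        device_map="cuda",  # new option added by this change; defaults to "cuda"
    )
    model = COCOM(cfg)

    # Inputs may be tokenized on CPU: the generation path now moves every
    # input tensor to self.decoder.device before calling generate().

Keeping device_map in the config means a single-GPU placement can be chosen at load time, while the .to(self.decoder.device) calls keep the generation inputs on whatever device the decoder actually ended up on.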