Tu2003716 committed
Commit 9af0971 · verified · Parent: acab43c

Convert bfloat16 to float16

Files changed (1): modeling_cocom.py (+7, -7)
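
For context on the dtype change (a note added here, not part of the commit): bfloat16 keeps float32's exponent range with reduced mantissa precision, while float16 has more precision but a far smaller range (max 65504) and is the 16-bit format available on GPUs without bfloat16 support. The trade-off can be checked with torch.finfo:

import torch

# Compare the two 16-bit float formats involved in this commit.
for dtype in (torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "eps:", info.eps)

# bfloat16: max ~3.39e38 (float32-like range), eps 0.0078125
# float16:  max 65504,   eps 0.0009765625 (more precision, much smaller range)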
modeling_cocom.py CHANGED

@@ -14,7 +14,7 @@ class BERT_Compressor(torch.nn.Module):
         super().__init__()
         # init model
         self.model_name = compr_model_name # base model name of BERT; example: bert-base-ucased
-        self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.bfloat16)
+        self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
         self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
         self.compr_rate = compr_rate # compression rate
         self.compressing_mode = compr_linear_type # linear layer type, could be either concat or mean.
@@ -23,7 +23,7 @@ class BERT_Compressor(torch.nn.Module):
             self.linear = torch.nn.Linear(self.model.config.hidden_size*self.compr_rate, decoder_hidden_size)
         elif self.compressing_mode == 'mean':
             self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
-        self.linear = self.linear.bfloat16()
+        self.linear = self.linear.float16()

     def forward(self, input_ids, attention_mask):
         # compressing context using BERT
@@ -97,7 +97,7 @@ class COCOM(PreTrainedModel):
         if cfg.quantization == "no":
             self.decoder = AutoModelForCausalLM.from_pretrained(
                 cfg.decoder_model_name,
-                torch_dtype=torch.bfloat16,
+                torch_dtype=torch.float16,
                 attn_implementation=attn_impl,
                 low_cpu_mem_usage = True,
                 )
@@ -105,14 +105,14 @@ class COCOM(PreTrainedModel):
             quant_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_quant_type='nf4',
-                bnb_4bit_compute_dtype='bfloat16',
+                bnb_4bit_compute_dtype='float16',
                 low_cpu_mem_usage = True,
                 )
             self.decoder = AutoModelForCausalLM.from_pretrained(
                 cfg.decoder_model_name,
                 quantization_config=quant_config,
                 attn_implementation=attn_impl,
-                torch_dtype=torch.bfloat16,
+                torch_dtype=torch.float16,
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
@@ -121,14 +121,14 @@ class COCOM(PreTrainedModel):
             quant_config = BitsAndBytesConfig(
                 load_in_8bit=True,
                 llm_int8_enable_fp32_cpu_offload=True,
-                bnb_4bit_compute_dtype='bfloat16',
+                bnb_4bit_compute_dtype='float16',
                 low_cpu_mem_usage = True,
                 )
             self.decoder = AutoModelForCausalLM.from_pretrained(
                 cfg.decoder_model_name,
                 quantization_config=quant_config,
                 attn_implementation=attn_impl,
-                torch_dtype=torch.bfloat16,
+                torch_dtype=torch.float16,
                 resume_download=True,
                 low_cpu_mem_usage = True,
                 trust_remote_code=True,
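
A note on the linear-layer cast above: standard torch.nn.Module objects expose .bfloat16() and .half() (or .to(torch.float16)), but not a .float16() method, so the new self.linear.float16() line would likely need to be written as self.linear.half(). A minimal sketch of that cast, using hypothetical layer sizes (not taken from the repository):

import torch

# Hypothetical dimensions, for illustration only.
hidden_size, compr_rate, decoder_hidden_size = 768, 4, 4096

linear = torch.nn.Linear(hidden_size * compr_rate, decoder_hidden_size)
linear = linear.half()  # casts weight and bias to torch.float16; equivalent to linear.to(torch.float16)

print(linear.weight.dtype)  # torch.float16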