Convert bfloat16 to float16
modeling_cocom.py CHANGED (+7 -7)
@@ -14,7 +14,7 @@ class BERT_Compressor(torch.nn.Module):
         super().__init__()
         # init model
         self.model_name = compr_model_name # base model name of BERT; example: bert-base-ucased
-        self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.bfloat16)
+        self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.float16)
         self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
         self.compr_rate = compr_rate # compression rate
         self.compressing_mode = compr_linear_type # linear layer type, could be either concat or mean.
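For context, a minimal standalone sketch of what the changed load does, assuming bert-base-uncased as the compressor backbone (a placeholder, matching the example in the comment above): passing torch_dtype=torch.float16 makes the weights come back in half precision.

import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical standalone check mirroring the changed line above.
model = AutoModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

# Floating-point parameters are now stored as float16.
print(next(model.parameters()).dtype)                    # torch.float16
print(f"{model.get_memory_footprint() / 1e6:.0f} MB")    # roughly half the float32 footprint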
@@ -23,7 +23,7 @@ class BERT_Compressor(torch.nn.Module):
            self.linear = torch.nn.Linear(self.model.config.hidden_size*self.compr_rate, decoder_hidden_size)
        elif self.compressing_mode == 'mean':
            self.linear = torch.nn.Linear(self.model.config.hidden_size, decoder_hidden_size)
-        self.linear = self.linear.bfloat16()
+        self.linear = self.linear.float16()
 
    def forward(self, input_ids, attention_mask):
        # compressing context using BERT
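One caveat on the new line above: torch.nn.Module exposes half(), bfloat16(), and to(dtype), but to my knowledge no float16() method. A minimal sketch of the equivalent cast for the projection layer, with hidden sizes as placeholders for self.model.config.hidden_size and decoder_hidden_size:

import torch

# Placeholder sizes; the real values come from the encoder and decoder configs.
linear = torch.nn.Linear(768, 4096)

# Equivalent half-precision casts; nn.Module has .half() / .to(torch.float16), not .float16().
linear = linear.half()              # or: linear = linear.to(torch.float16)
print(linear.weight.dtype)          # torch.float16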
@@ -97,7 +97,7 @@ class COCOM(PreTrainedModel):
        if cfg.quantization == "no":
            self.decoder = AutoModelForCausalLM.from_pretrained(
                cfg.decoder_model_name,
-                torch_dtype=torch.bfloat16,
+                torch_dtype=torch.float16,
                attn_implementation=attn_impl,
                low_cpu_mem_usage = True,
                )
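As a rough illustration of this unquantized path after the change (the decoder name and attention implementation below are placeholders, not the repository's actual cfg values):

import torch
from transformers import AutoModelForCausalLM

# Hypothetical stand-ins for cfg.decoder_model_name and attn_impl.
decoder = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
    low_cpu_mem_usage=True,
)
print(next(decoder.parameters()).dtype)  # torch.float16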
@@ -105,14 +105,14 @@ class COCOM(PreTrainedModel):
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type='nf4',
-                bnb_4bit_compute_dtype='bfloat16',
+                bnb_4bit_compute_dtype='float16',
                low_cpu_mem_usage = True,
                )
            self.decoder = AutoModelForCausalLM.from_pretrained(
                cfg.decoder_model_name,
                quantization_config=quant_config,
                attn_implementation=attn_impl,
-                torch_dtype=torch.bfloat16,
+                torch_dtype=torch.float16,
                resume_download=True,
                low_cpu_mem_usage = True,
                trust_remote_code=True,
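For the 4-bit branch, a hedged sketch of the resulting configuration: BitsAndBytesConfig accepts the compute dtype either as the string 'float16' (as used above) or as torch.float16, and the NF4 matmuls are then carried out in half precision. The model name is again a placeholder for cfg.decoder_model_name.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,   # equivalent to the string 'float16' used above
)
decoder = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",   # placeholder for cfg.decoder_model_name
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)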
@@ -121,14 +121,14 @@ class COCOM(PreTrainedModel):
            quant_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
-                bnb_4bit_compute_dtype='bfloat16',
+                bnb_4bit_compute_dtype='float16',
                low_cpu_mem_usage = True,
                )
            self.decoder = AutoModelForCausalLM.from_pretrained(
                cfg.decoder_model_name,
                quantization_config=quant_config,
                attn_implementation=attn_impl,
-                torch_dtype=torch.bfloat16,
+                torch_dtype=torch.float16,
                resume_download=True,
                low_cpu_mem_usage = True,
                trust_remote_code=True,
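The 8-bit branch in the same spirit. As I understand it, bnb_4bit_compute_dtype only affects 4-bit kernels, so in this load_in_8bit config the effective dtype change comes from torch_dtype; the sketch below therefore omits it (model name is a placeholder as before).

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)
decoder = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",   # placeholder for cfg.decoder_model_name
    quantization_config=quant_config,
    torch_dtype=torch.float16,              # modules kept out of int8 are loaded in fp16
    low_cpu_mem_usage=True,
)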