MekkCyber committed on
Commit
364af2c
·
1 Parent(s): 92eb715

change dtype

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -25,7 +25,7 @@ def check_model_exists(oauth_token: gr.OAuthToken | None, username, model_name,
25
  if quantized_model_name :
26
  repo_name = f"{username}/{quantized_model_name}"
27
  else :
28
- repo_name = f"{username}/{model_name.split('/')[-1]}-BNB-INT4"
29
 
30
  if repo_name in model_names:
31
  return f"Model '{repo_name}' already exists in your repository."
@@ -83,7 +83,7 @@ def quantize_model(model_name, quant_type_4, double_quant_4, compute_type_4, qua
83
  bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4],
84
  )
85
 
86
- model = AutoModel.from_pretrained(model_name, quantization_config=quantization_config, device_map="cpu", use_auth_token=auth_token.token)
87
  for _ , module in model.named_modules():
88
  if isinstance(module, Linear4bit):
89
  module.to("cuda")
@@ -99,7 +99,7 @@ def save_model(model, model_name, quant_type_4, double_quant_4, compute_type_4,
99
  if quantized_model_name :
100
  repo_name = f"{username}/{quantized_model_name}"
101
  else :
102
- repo_name = f"{username}/{model_name.split('/')[-1]}-BNB-INT4"
103
 
104
 
105
  model_card = create_model_card(repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4)
@@ -425,5 +425,4 @@ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
425
 
426
  if __name__ == "__main__":
427
  demo.launch(share=True)
428
- # Launch the app
429
- # demo.launch(share=True, debug=True)
 
25
  if quantized_model_name :
26
  repo_name = f"{username}/{quantized_model_name}"
27
  else :
28
+ repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
29
 
30
  if repo_name in model_names:
31
  return f"Model '{repo_name}' already exists in your repository."
 
83
  bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4],
84
  )
85
 
86
+ model = AutoModel.from_pretrained(model_name, quantization_config=quantization_config, device_map="cpu", use_auth_token=auth_token.token, torch_dtype=torch.bfloat16)
87
  for _ , module in model.named_modules():
88
  if isinstance(module, Linear4bit):
89
  module.to("cuda")
 
99
  if quantized_model_name :
100
  repo_name = f"{username}/{quantized_model_name}"
101
  else :
102
+ repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
103
 
104
 
105
  model_card = create_model_card(repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4)
 
425
 
426
  if __name__ == "__main__":
427
  demo.launch(share=True)
428
+