MekkCyber committed on
Commit
b5887d5
·
1 Parent(s): 931ff17
Files changed (3)
  1. .gradio/certificate.pem +31 -0
  2. app.py +414 -120
  3. requirements.txt +1 -1
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
app.py CHANGED
@@ -2,16 +2,31 @@ import gradio as gr
 import torch
 from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
 import tempfile
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub import list_models
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
 from packaging import version
 import os
-import spaces
+from torchao.quantization import (
+    Int4WeightOnlyConfig,
+    Int8WeightOnlyConfig,
+    Int8DynamicActivationInt8WeightConfig,
+    Float8WeightOnlyConfig,
+)
 
 MAP_QUANT_TYPE_TO_NAME = {
-    "int4_weight_only": "int4wo", "int8_weight_only": "int8wo", "int8_dynamic_activation_int8_weight": "int8da8w"
+    "int4_weight_only": "int4wo",
+    "int8_weight_only": "int8wo",
+    "int8_dynamic_activation_int8_weight": "int8da8w",
+    "autoquant": "autoquant",
 }
+MAP_QUANT_TYPE_TO_CONFIG = {
+    "int4_weight_only": Int4WeightOnlyConfig,
+    "int8_weight_only": Int8WeightOnlyConfig,
+    "int8_dynamic_activation_int8_weight": Int8DynamicActivationInt8WeightConfig,
+    "float8_weight_only": Float8WeightOnlyConfig,
+}
+
 
 def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
     # ^ expect a gr.OAuthProfile object as input to get the user's profile
@@ -20,19 +35,29 @@ def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) ->
         return "Hello !"
     return f"Hello {profile.name} !"
 
-def check_model_exists(oauth_token: gr.OAuthToken | None, username, quantization_type, group_size, model_name, quantized_model_name):
+
+def check_model_exists(
+    oauth_token: gr.OAuthToken | None,
+    username,
+    quantization_type,
+    group_size,
+    model_name,
+    quantized_model_name,
+):
     """Check if a model exists in the user's Hugging Face repository."""
     try:
         models = list_models(author=username, token=oauth_token.token)
         model_names = [model.id for model in models]
-        if quantized_model_name :
+        if quantized_model_name:
             repo_name = f"{username}/{quantized_model_name}"
-        else :
-            if quantization_type == "int4_weight_only" :
-                repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
-            else :
-                repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
-
+        else:
+            if (
+                quantization_type == "int4_weight_only"
+                or quantization_type == "int8_weight_only"
+            ) and (group_size is not None):
+                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
+            else:
+                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
         if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
         else:
@@ -40,62 +65,160 @@ def check_model_exists(oauth_token: gr.OAuthToken | None, username, quantization
     except Exception as e:
         return f"Error checking model existence: {str(e)}"
 
+
 def create_model_card(model_name, quantization_type, group_size):
-    model_card = f"""---
+    # Try to download the original README
+    original_readme = ""
+    original_yaml_header = ""
+    try:
+        # Download the README.md file from the original model
+        model_path = snapshot_download(
+            repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
+        )
+        readme_path = os.path.join(model_path, "README.md")
+
+        if os.path.exists(readme_path):
+            with open(readme_path, "r", encoding="utf-8") as f:
+                content = f.read()
+
+            if content.startswith("---"):
+                parts = content.split("---", 2)
+                if len(parts) >= 3:
+                    original_yaml_header = parts[1]
+                    original_readme = "---".join(parts[2:])
+                else:
+                    original_readme = content
+            else:
+                original_readme = content
+    except Exception as e:
+        print(f"Error reading original README: {str(e)}")
+        original_readme = ""
+
+    # Create new YAML header with base_model field
+    yaml_header = f"""---
 base_model:
-- {model_name}
----
+- {model_name}"""
 
+    # Add any original YAML fields except base_model
+    if original_yaml_header:
+        in_base_model_section = False
+        found_tags = False
+        for line in original_yaml_header.strip().split("\n"):
+            # Skip if we're in a base_model section that continues to the next line
+            if in_base_model_section:
+                if (
+                    line.strip().startswith("-")
+                    or not line.strip()
+                    or line.startswith(" ")
+                ):
+                    continue
+                else:
+                    in_base_model_section = False
+
+            # Check for base_model field
+            if line.strip().startswith("base_model:"):
+                in_base_model_section = True
+                # If base_model has inline value (like "base_model: model_name")
+                if ":" in line and len(line.split(":", 1)[1].strip()) > 0:
+                    in_base_model_section = False
+                continue
+
+            # Check for tags field and add torchao-my-repo
+            if line.strip().startswith("tags:"):
+                found_tags = True
+                yaml_header += f"\n{line}"
+                yaml_header += "\n- torchao-my-repo"
+                continue
+
+            yaml_header += f"\n{line}"
+
+        # If tags field wasn't found, add it
+        if not found_tags:
+            yaml_header += "\ntags:"
+            yaml_header += "\n- torchao-my-repo"
+    # Complete the YAML header
+    yaml_header += "\n---"
+
+    # Create the quantization info section
+    quant_info = f"""
 # {model_name} (Quantized)
 
 ## Description
-This model is a quantized version of the original model `{model_name}`. It has been quantized using {quantization_type} quantization with torchao.
+This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}).
+
+It's quantized using the TorchAO library using the [torchao-my-repo](https://huggingface.co/spaces/pytorch/torchao-my-repo) space.
 
 ## Quantization Details
 - **Quantization Type**: {quantization_type}
-- **Group Size**: {group_size if quantization_type == "int4_weight_only" else None}
+- **Group Size**: {group_size}
 
-## Usage
-You can use this model in your applications by loading it directly from the Hugging Face Hub:
+"""
 
-```python
-from transformers import AutoModel
+    # Combine everything
+    model_card = yaml_header + quant_info
 
-model = AutoModel.from_pretrained("{model_name}")"""
-
+    # Append original README content if available
+    if original_readme and not original_readme.isspace():
+        model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme
     return model_card
 
-def load_model(model_name, quantization_config, auth_token) :
-    return AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="cpu", use_auth_token=auth_token.token)
-
-def load_model_cpu(model_name, quantization_config, auth_token) :
-    return AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, quantization_config=quantization_config, use_auth_token=auth_token.token)
 
-def quantize_model(model_name, quantization_type, group_size=128, auth_token=None, username=None):
+def quantize_model(
+    model_name, quantization_type, group_size=128, auth_token=None, username=None
+):
     print(f"Quantizing model: {quantization_type}")
-    if quantization_type == "int4_weight_only" :
+    if (
+        quantization_type == "int4_weight_only"
+        or quantization_type == "int8_weight_only"
+    ):
         quantization_config = TorchAoConfig(quantization_type, group_size=group_size)
-    else :
+    else:
         quantization_config = TorchAoConfig(quantization_type)
-    model = load_model(model_name, quantization_config=quantization_config, auth_token=auth_token)
+    model = AutoModel.from_pretrained(
+        model_name,
+        torch_dtype="auto",
+        quantization_config=quantization_config,
+        device_map="cpu",
+        use_auth_token=auth_token.token,
+    )
 
     return model
 
-def save_model(model, model_name, quantization_type, group_size=128, username=None, auth_token=None, quantized_model_name=None):
+
+def save_model(
+    model,
+    model_name,
+    quantization_type,
+    group_size=128,
+    username=None,
+    auth_token=None,
+    quantized_model_name=None,
+):
     print("Saving quantized model")
     with tempfile.TemporaryDirectory() as tmpdirname:
+        # Load and save the tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, use_auth_token=auth_token.token
+        )
+        tokenizer.save_pretrained(tmpdirname, use_auth_token=auth_token.token)
 
+        # Save the model
+        model.save_pretrained(
+            tmpdirname, safe_serialization=False, use_auth_token=auth_token.token
+        )
 
-        model.save_pretrained(tmpdirname, safe_serialization=False, use_auth_token=auth_token.token)
-        if quantized_model_name :
+        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
-        else :
-            if quantization_type == "int4_weight_only" :
-                repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
-            else :
-                repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
+        else:
+            if (
+                quantization_type == "int4_weight_only"
+                or quantization_type == "int8_weight_only"
+            ) and (group_size is not None):
+                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
+            else:
+                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
 
-        model_card = create_model_card(repo_name, quantization_type, group_size)
+        model_card = create_model_card(model_name, quantization_type, group_size)
         with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)
         # Push to Hub
@@ -106,130 +229,301 @@ def save_model(model, model_name, quantization_type, group_size=128, username=No
             repo_id=repo_name,
             repo_type="model",
         )
-    return f'<h1> 🤗 DONE</h1><br/>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a>'
 
-def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quantization_type, group_size, quantized_model_name):
-    if oauth_token is None :
-        return "Error : Please Sign In to your HuggingFace account to use the quantizer"
+    import io
+    from contextlib import redirect_stdout
+    import html
+
+    # Capture the model architecture string
+    f = io.StringIO()
+    with redirect_stdout(f):
+        print(model)
+    model_architecture_str = f.getvalue()
+
+    # Escape HTML characters and format with line breaks
+    model_architecture_str_html = html.escape(model_architecture_str).replace(
+        "\n", "<br/>"
+    )
+
+    # Format it for display in markdown with proper styling
+    model_architecture_info = f"""
+    <div class="model-architecture-container" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
+        <h3 style="margin-top: 0; color: #2E7D32;">📋 Model Architecture</h3>
+        <div class="model-architecture" style="max-height: 500px; overflow-y: auto; overflow-x: auto; background-color: #f5f5f5; padding: 5px; border-radius: 8px; font-family: monospace; white-space: pre-wrap;">
+            <div style="line-height: 1.2; font-size: 0.75em;">{model_architecture_str_html}</div>
+        </div>
+    </div>
+    """
+
+    repo_link = f"""
+    <div class="repo-link" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
+        <h3 style="margin-top: 0; color: #2E7D32;">🔗 Repository Link</h3>
+        <p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a></p>
+    </div>
+    """
+    return (
+        f"<h1>🎉 Quantization Completed</h1><br/>{repo_link}{model_architecture_info}"
+    )
+
+
+def quantize_and_save(
+    profile: gr.OAuthProfile | None,
+    oauth_token: gr.OAuthToken | None,
+    model_name,
+    quantization_type,
+    group_size,
+    quantized_model_name,
+):
+    if oauth_token is None:
+        return """
+        <div class="error-box">
+            <h3>❌ Authentication Error</h3>
+            <p>Please sign in to your HuggingFace account to use the quantizer.</p>
+        </div>
+        """
     if not profile:
-        return "Error: Please Sign In to your HuggingFace account to use the quantizer"
-    exists_message = check_model_exists(oauth_token, profile.username, quantization_type, group_size, model_name, quantized_model_name)
-    if exists_message :
-        return exists_message
-    if quantization_type == "int4_weight_only" :
-        return "int4_weight_only not supported on cpu"
-    if not group_size.isdigit() :
-        return "group_size must be a number"
-
-    group_size = int(group_size)
+        return """
+        <div class="error-box">
+            <h3>❌ Authentication Error</h3>
+            <p>Please sign in to your HuggingFace account to use the quantizer.</p>
+        </div>
+        """
+    if not group_size.isdigit():
+        if group_size != "":
+            return """
+            <div class="error-box">
+                <h3>❌ Group Size Error</h3>
+                <p>Group Size is a number for int4_weight_only and int8_weight_only or empty for int8_weight_only</p>
+            </div>
+            """
+
+    if group_size and group_size.strip():
+        group_size = int(group_size)
+    else:
+        group_size = None
+
+    exists_message = check_model_exists(
+        oauth_token,
+        profile.username,
+        quantization_type,
+        group_size,
+        model_name,
+        quantized_model_name,
+    )
+    if exists_message:
+        return f"""
+        <div class="warning-box">
+            <h3>⚠️ Model Already Exists</h3>
+            <p>{exists_message}</p>
+        </div>
+        """
+    # if quantization_type == "int4_weight_only" :
+    #     return "int4_weight_only not supported on cpu"
 
     try:
-        quantized_model = quantize_model(model_name, quantization_type, group_size, oauth_token, profile.username)
-        return save_model(quantized_model, model_name, quantization_type, group_size, profile.username, oauth_token, quantized_model_name)
-    except Exception as e :
-        return e
+        quantized_model = quantize_model(
+            model_name, quantization_type, group_size, oauth_token, profile.username
+        )
+        return save_model(
+            quantized_model,
+            model_name,
+            quantization_type,
+            group_size,
+            profile.username,
+            oauth_token,
+            quantized_model_name,
+        )
+    except Exception as e:
+        return str(e)
 
 
-css="""/* Custom CSS to allow scrolling */
+def get_model_size(model):
+    """
+    Calculate the size of a PyTorch model in gigabytes.
+
+    Args:
+        model: PyTorch model
+
+    Returns:
+        float: Size of the model in GB
+    """
+    # Get model state dict
+    state_dict = model.state_dict()
+
+    # Calculate total size in bytes
+    total_size = 0
+    for param in state_dict.values():
+        # Calculate bytes for each parameter
+        total_size += param.nelement() * param.element_size()
+
+    # Convert bytes to gigabytes (1 GB = 1,073,741,824 bytes)
+    size_gb = total_size / (1024**3)
+    size_gb = round(size_gb, 2)
+
+    return size_gb
+
+
+# Add enhanced CSS styling
+css = """
+/* Custom CSS for enhanced UI */
 .gradio-container {overflow-y: auto;}
+
+/* Fix alignment for radio buttons and dropdowns */
+.gradio-radio, .gradio-dropdown {
+    display: flex !important;
+    align-items: center !important;
+    margin: 10px 0 !important;
+}
+
+/* Consistent spacing and alignment */
+.gradio-dropdown, .gradio-textbox, .gradio-radio {
+    margin-bottom: 12px !important;
+    width: 100% !important;
+}
+
+/* Quantize button styling with glow effect */
+button[variant="primary"] {
+    background: linear-gradient(135deg, #3B82F6, #10B981) !important;
+    color: white !important;
+    padding: 16px 32px !important;
+    font-size: 1.1rem !important;
+    font-weight: 700 !important;
+    border: none !important;
+    border-radius: 12px !important;
+    box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
+    transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
+    position: relative;
+    overflow: hidden;
+    animation: glow 1.5s ease-in-out infinite alternate;
+}
+
+button[variant="primary"]::before {
+    content: "✨ ";
+}
+
+button[variant="primary"]:hover {
+    transform: translateY(-5px) scale(1.05) !important;
+    box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important;
+}
+
+@keyframes glow {
+    from {
+        box-shadow: 0 0 10px rgba(59, 130, 246, 0.5);
+    }
+    to {
+        box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5);
+    }
+}
+
+/* Login button styling */
+#login-button {
+    background: linear-gradient(135deg, #3B82F6, #10B981) !important;
+    color: white !important;
+    font-weight: 700 !important;
+    border: none !important;
+    border-radius: 12px !important;
+    box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
+    transition: all 0.3s ease !important;
+    max-width: 300px !important;
+    margin: 0 auto !important;
+}
 """
-with gr.Blocks(theme=gr.themes.Ocean(), css=css) as app:
+
+# Update the main app layout
+with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
     gr.Markdown(
         """
-    # 🤗 LLM Model TorchAO Quantization App
+    # 🤗 TorchAO Model Quantizer ✨
 
     Quantize your favorite Hugging Face models using TorchAO and save them to your profile!
+
+    <br/>
     """
     )
 
     gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
 
     m1 = gr.Markdown()
-    app.load(hello, inputs=None, outputs=m1)
-
-
-    radio = gr.Radio(["show", "hide"], label="Show Instructions", value="hide")
-    instructions = gr.Markdown(
-        """
-        ## Instructions
-        1. Login to your HuggingFace account
-        2. Enter the name of the Hugging Face LLM model you want to quantize (Make sure you have access to it)
-        3. Choose the quantization type.
-        4. Optionally, specify the group size.
-        5. Optionally, choose a custom name for the quantized model
-        6. Click "Quantize and Save Model" to start the process.
-        7. Once complete, you'll receive a link to the quantized model on Hugging Face.
-
-        Note: This process may take some time depending on the model size and your hardware you can check the container logs to see where are you at in the process!
-        """,
-        visible=False
-    )
-    def update_visibility(radio):
-        value = radio
-        if value == "show":
-            return gr.Textbox(visible=True)
-        else:
-            return gr.Textbox(visible=False)
-    radio.change(update_visibility, radio, instructions)
+    demo.load(hello, inputs=None, outputs=m1)
 
     with gr.Row():
         with gr.Column():
             with gr.Row():
                 model_name = HuggingfaceHubSearch(
-                    label="Hub Model ID",
+                    label="🔍 Hub Model ID",
                     placeholder="Search for model id on Huggingface",
                     search_type="model",
-                    scale=2
                 )
+
+            gr.Markdown("""### ⚙️ Quantization Settings""")
             with gr.Row():
                 with gr.Column():
                     quantization_type = gr.Dropdown(
-                        info="Quantization Type",
-                        choices=["int4_weight_only", "int8_weight_only", "int8_dynamic_activation_int8_weight"],
+                        info="Select the Quantization method",
+                        choices=[
+                            "int4_weight_only",
+                            "int8_weight_only",
+                            "int8_dynamic_activation_int8_weight",
+                            "autoquant",
+                        ],
                         value="int8_weight_only",
                         filterable=False,
                         show_label=False,
                     )
                     group_size = gr.Textbox(
-                        info="Group Size (only for int4_weight_only)",
-                        value=128,
+                        info="Group Size (only for int4_weight_only and int8_weight_only)",
+                        value="128",
                         interactive=True,
-                        show_label=False
+                        show_label=False,
                     )
                     quantized_model_name = gr.Textbox(
-                        info="Model Name (optional : to override default)",
+                        info="Custom name for your quantized model (optional)",
                         value="",
                         interactive=True,
-                        show_label=False
+                        show_label=False,
                     )
+
         with gr.Column():
-            quantize_button = gr.Button("Quantize and Save Model", variant="primary")
-            output_link = gr.Markdown(label="Quantized Model Link", container=True, min_height=40)
-
-
-    # Adding CSS styles for the username box
-    app.css = """
-    #username-box {
-        background-color: #f0f8ff; /* Light color */
-        border-radius: 8px;
-        padding: 10px;
-    }
-    """
-    app.css = """
-    .center-button {
-        display: flex;
-        justify-content: center;
-        align-items: center;
-        margin: 0 auto; /* Center horizontally */
-    }
-    """
-
+            quantize_button = gr.Button(
+                "🚀 Quantize and Push to Hub", variant="primary"
+            )
+            output_link = gr.Markdown(
+                label="🔗 Quantized Model Info", container=True, min_height=200
+            )
+
+    # Add information section
+    with gr.Accordion("📚 About TorchAO Quantization", open=True):
+        gr.Markdown(
+            """
+            ## 📝 Quantization Options
+
+            ### Quantization Types
+            - **int4_weight_only**: 4-bit weight-only quantization
+            - **int8_weight_only**: 8-bit weight-only quantization
+            - **int8_dynamic_activation_int8_weight**: 8-bit quantization for both weights and activations
+
+            ### Group Size
+            - Only applicable for int4_weight_only and int8_weight_only quantization
+            - Default value is 128
+            - Affects the granularity of quantization

+            ## 🔍 How It Works
+            1. Downloads the original model
+            2. Applies TorchAO quantization with your selected settings
+            3. Uploads the quantized model to your HuggingFace account
+
+            ## 📊 Memory Benefits
+            - int4_weight_only can reduce model size by up to 75%
+            - int8_weight_only typically reduces size by about 50%
+            """
+        )
+
+    # Keep existing click handler
     quantize_button.click(
         fn=quantize_and_save,
         inputs=[model_name, quantization_type, group_size, quantized_model_name],
-        outputs=[output_link]
+        outputs=[output_link],
     )
 
-
 # Launch the app
-app.launch()
+demo.launch(share=True)
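For context, a repo produced by the new `save_model` path can be loaded back with stock `transformers`. A minimal sketch, assuming a hypothetical output repo id following the Space's `{model}-ao-{type}-gs{group_size}` naming (substitute one the Space actually created; `torchao` must be installed to deserialize the quantized weights):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical repo id; replace with a repo generated by this Space
repo_id = "your-username/Llama-3.2-1B-ao-int8wo-gs128"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype="auto",  # keep the dtypes the quantized checkpoint was saved with
    device_map="cpu",    # the Space itself quantizes and saves on CPU
)

inputs = tokenizer("Hello!", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```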
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-git+https://github.com/huggingface/transformers.git@main#egg=transformers
+transformers
 accelerate
 torchao
 huggingface-hub
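The switch from a `transformers` git checkout to the released package suggests the `TorchAoConfig` integration the app relies on is now in a stable release. Under that assumption, the committed quantize-and-save flow can be reproduced outside the Space; a minimal sketch mirroring `quantize_model`/`save_model` above (the model id is only an example):

```python
from transformers import AutoModel, AutoTokenizer, TorchAoConfig

model_name = "facebook/opt-125m"  # example model id
quantization_type = "int8_weight_only"
group_size = 128  # used for int4_weight_only / int8_weight_only, as in the app

quantization_config = TorchAoConfig(quantization_type, group_size=group_size)
model = AutoModel.from_pretrained(
    model_name,
    torch_dtype="auto",
    quantization_config=quantization_config,
    device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# torchao tensors are pickled rather than stored as safetensors,
# hence safe_serialization=False (matching save_model in this commit)
model.save_pretrained("quantized-model", safe_serialization=False)
tokenizer.save_pretrained("quantized-model")
```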