MekkCyber committed on
Commit 70dd883 · 1 Parent(s): b5887d5

updating to use AoBaseConfig
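
For context (an aside, not part of the commit itself): the change moves from passing a string quant type to TorchAoConfig to passing a torchao AoBaseConfig instance such as Int8WeightOnlyConfig. A minimal sketch of that pattern, assuming a recent transformers/torchao install; the model id below is only a placeholder:

from transformers import AutoModel, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

# Build the torchao config object first, then wrap it in TorchAoConfig.
quant_config = Int8WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_config)

# "facebook/opt-125m" is only a placeholder model id for this sketch.
model = AutoModel.from_pretrained(
    "facebook/opt-125m",
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config,
)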

Files changed (1)
  1. app.py +65 -19
app.py CHANGED
@@ -12,12 +12,15 @@ from torchao.quantization import (
     Int8WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     Float8WeightOnlyConfig,
+    Float8DynamicActivationFloat8WeightConfig,
 )
 
 MAP_QUANT_TYPE_TO_NAME = {
     "int4_weight_only": "int4wo",
     "int8_weight_only": "int8wo",
-    "int8_dynamic_activation_int8_weight": "int8da8w",
+    "int8_dynamic_activation_int8_weight": "int8da8w8",
+    "float8_weight_only": "float8wo",
+    "float8_dynamic_activation_float8_weight": "float8da8w8",
     "autoquant": "autoquant",
 }
 MAP_QUANT_TYPE_TO_CONFIG = {
@@ -25,6 +28,7 @@ MAP_QUANT_TYPE_TO_CONFIG = {
     "int8_weight_only": Int8WeightOnlyConfig,
     "int8_dynamic_activation_int8_weight": Int8DynamicActivationInt8WeightConfig,
     "float8_weight_only": Float8WeightOnlyConfig,
+    "float8_dynamic_activation_float8_weight": Float8DynamicActivationFloat8WeightConfig,
 }
 
 
@@ -164,16 +168,30 @@ It's quantized using the TorchAO library using the [torchao-my-repo](https://hug
 
 
 def quantize_model(
-    model_name, quantization_type, group_size=128, auth_token=None, username=None
+    model_name, quantization_type, group_size=128, auth_token=None, username=None, progress=gr.Progress()
 ):
     print(f"Quantizing model: {quantization_type}")
+    progress(0, desc="Preparing Quantization")
     if (
-        quantization_type == "int4_weight_only"
-        or quantization_type == "int8_weight_only"
+        quantization_type == "int8_weight_only"
     ):
-        quantization_config = TorchAoConfig(quantization_type, group_size=group_size)
-    else:
+        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
+            group_size=group_size
+        )
+        quantization_config = TorchAoConfig(quant_config)
+    elif quantization_type == "int4_weight_only":
+        from torchao.dtypes import Int4CPULayout
+
+        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
+            group_size=group_size, layout=Int4CPULayout()
+        )
+        quantization_config = TorchAoConfig(quant_config)
+    elif quantization_type == "autoquant":
         quantization_config = TorchAoConfig(quantization_type)
+    else:
+        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]()
+        quantization_config = TorchAoConfig(quant_config)
+    progress(0.10, desc="Quantizing model")
     model = AutoModel.from_pretrained(
         model_name,
         torch_dtype="auto",
@@ -181,7 +199,7 @@ def quantize_model(
         device_map="cpu",
         use_auth_token=auth_token.token,
     )
-
+    progress(0.45, desc="Quantization completed")
     return model
 
 
@@ -193,7 +211,10 @@ def save_model(
     username=None,
     auth_token=None,
     quantized_model_name=None,
+    public=True,
+    progress=gr.Progress(),
 ):
+    progress(0.50, desc="Preparing to push")
     print("Saving quantized model")
     with tempfile.TemporaryDirectory() as tmpdirname:
         # Load and save the tokenizer
@@ -203,10 +224,11 @@ def save_model(
         tokenizer.save_pretrained(tmpdirname, use_auth_token=auth_token.token)
 
         # Save the model
+        progress(0.60, desc="Saving model")
         model.save_pretrained(
             tmpdirname, safe_serialization=False, use_auth_token=auth_token.token
         )
-
+
         if quantized_model_name:
             repo_name = f"{username}/{quantized_model_name}"
         else:
@@ -217,19 +239,21 @@ def save_model(
                 repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
             else:
                 repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
-
+        progress(0.70, desc="Creating model card")
         model_card = create_model_card(model_name, quantization_type, group_size)
         with open(os.path.join(tmpdirname, "README.md"), "w") as f:
             f.write(model_card)
         # Push to Hub
         api = HfApi(token=auth_token.token)
-        api.create_repo(repo_name, exist_ok=True)
+        api.create_repo(repo_name, exist_ok=True, private=not public)
+        progress(0.80, desc="Pushing to Hub")
         api.upload_folder(
             folder_path=tmpdirname,
             repo_id=repo_name,
             repo_type="model",
         )
-
+        progress(1.00, desc="Pushing to Hub completed")
+
 import io
 from contextlib import redirect_stdout
 import html
@@ -273,6 +297,7 @@ def quantize_and_save(
     quantization_type,
     group_size,
     quantized_model_name,
+    public,
 ):
     if oauth_token is None:
         return """
@@ -332,8 +357,10 @@ def quantize_and_save(
             profile.username,
             oauth_token,
             quantized_model_name,
+            public,
         )
     except Exception as e:
+        # raise e
         return str(e)
 
 
@@ -464,24 +491,44 @@ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
                 "int4_weight_only",
                 "int8_weight_only",
                 "int8_dynamic_activation_int8_weight",
+                "float8_weight_only",
+                "float8_dynamic_activation_float8_weight",
                 "autoquant",
             ],
             value="int8_weight_only",
             filterable=False,
             show_label=False,
         )
+
        group_size = gr.Textbox(
            info="Group Size (only for int4_weight_only and int8_weight_only)",
            value="128",
-            interactive=True,
+            interactive=(quantization_type.value == "int4_weight_only" or quantization_type.value == "int8_weight_only"),
            show_label=False,
        )
-        quantized_model_name = gr.Textbox(
-            info="Custom name for your quantized model (optional)",
-            value="",
-            interactive=True,
-            show_label=False,
+
+        gr.Markdown(
+            """
+            ### 💾 Saving Settings
+            """
        )
+        with gr.Row():
+            quantized_model_name = gr.Textbox(
+                label="✏️ Model Name",
+                info="Model Name (optional : to override default)",
+                value="",
+                interactive=True,
+                elem_classes="model-name-textbox",
+                show_label=False,
+            )
+        with gr.Row():
+            public = gr.Checkbox(
+                label="🌐 Make model public",
+                info="If checked, the model will be publicly accessible",
+                value=True,
+                interactive=True,
+                show_label=True,
+            )
 
        with gr.Column():
            quantize_button = gr.Button(
@@ -517,11 +564,10 @@ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
        - int8_weight_only typically reduces size by about 50%
        """
    )
-
    # Keep existing click handler
    quantize_button.click(
        fn=quantize_and_save,
-        inputs=[model_name, quantization_type, group_size, quantized_model_name],
+        inputs=[model_name, quantization_type, group_size, quantized_model_name, public],
        outputs=[output_link],
    )
 
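
A rough sketch of the new push behaviour added in this commit, where the 🌐 checkbox toggles repo visibility via private=not public; the repo id, token, and local folder below are placeholders rather than values from the Space:

from huggingface_hub import HfApi

public = False  # value that would come from the new Gradio checkbox
repo_name = "your-username/opt-125m-ao-int8wo"  # placeholder following the Space's naming scheme

api = HfApi(token="hf_xxx")  # placeholder token
api.create_repo(repo_name, exist_ok=True, private=not public)  # repo is private when the box is unchecked
api.upload_folder(
    folder_path="./quantized-model",  # placeholder folder holding the saved model and README
    repo_id=repo_name,
    repo_type="model",
)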