import gradio as gr
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
import tempfile
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from packaging import version
import os
from torchao.quantization import (
Int4WeightOnlyConfig,
Int8WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Float8WeightOnlyConfig,
)
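# Short suffixes used when composing the default quantized repo name,
# e.g. "<model>-ao-int4wo-gs128" for int4_weight_only with group_size=128.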
MAP_QUANT_TYPE_TO_NAME = {
    "int4_weight_only": "int4wo",
    "int8_weight_only": "int8wo",
    "int8_dynamic_activation_int8_weight": "int8da8w",
    "float8_weight_only": "float8wo",
    "autoquant": "autoquant",
}
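# torchao configuration classes for each supported quantization type;
# transformers' TorchAoConfig also accepts the string name directly (see quantize_model).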
MAP_QUANT_TYPE_TO_CONFIG = {
"int4_weight_only": Int4WeightOnlyConfig,
"int8_weight_only": Int8WeightOnlyConfig,
"int8_dynamic_activation_int8_weight": Int8DynamicActivationInt8WeightConfig,
"float8_weight_only": Float8WeightOnlyConfig,
}
def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
# ^ expect a gr.OAuthProfile object as input to get the user's profile
# if the user is not logged in, profile will be None
if profile is None:
return "Hello !"
return f"Hello {profile.name} !"
def check_model_exists(
oauth_token: gr.OAuthToken | None,
username,
quantization_type,
group_size,
model_name,
quantized_model_name,
):
"""Check if a model exists in the user's Hugging Face repository."""
try:
models = list_models(author=username, token=oauth_token.token)
model_names = [model.id for model in models]
if quantized_model_name:
repo_name = f"{username}/{quantized_model_name}"
else:
if (
quantization_type == "int4_weight_only"
or quantization_type == "int8_weight_only"
) and (group_size is not None):
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
else:
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
if repo_name in model_names:
return f"Model '{repo_name}' already exists in your repository."
else:
return None # Model does not exist
except Exception as e:
return f"Error checking model existence: {str(e)}"
def create_model_card(model_name, quantization_type, group_size):
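    """Build a README for the quantized repo: reuse the original model's YAML
    header (minus base_model, plus a torchao-my-repo tag), add a quantization
    summary section, and append the original README body."""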
# Try to download the original README
original_readme = ""
original_yaml_header = ""
try:
# Download the README.md file from the original model
model_path = snapshot_download(
repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
)
readme_path = os.path.join(model_path, "README.md")
if os.path.exists(readme_path):
with open(readme_path, "r", encoding="utf-8") as f:
content = f.read()
if content.startswith("---"):
parts = content.split("---", 2)
if len(parts) >= 3:
original_yaml_header = parts[1]
original_readme = "---".join(parts[2:])
else:
original_readme = content
else:
original_readme = content
except Exception as e:
print(f"Error reading original README: {str(e)}")
original_readme = ""
# Create new YAML header with base_model field
yaml_header = f"""---
base_model:
- {model_name}"""
# Add any original YAML fields except base_model
if original_yaml_header:
in_base_model_section = False
found_tags = False
for line in original_yaml_header.strip().split("\n"):
# Skip if we're in a base_model section that continues to the next line
if in_base_model_section:
if (
line.strip().startswith("-")
or not line.strip()
or line.startswith(" ")
):
continue
else:
in_base_model_section = False
# Check for base_model field
if line.strip().startswith("base_model:"):
in_base_model_section = True
# If base_model has inline value (like "base_model: model_name")
if ":" in line and len(line.split(":", 1)[1].strip()) > 0:
in_base_model_section = False
continue
            # Check for tags field and add torchao-my-repo
if line.strip().startswith("tags:"):
found_tags = True
yaml_header += f"\n{line}"
yaml_header += "\n- torchao-my-repo"
continue
yaml_header += f"\n{line}"
# If tags field wasn't found, add it
if not found_tags:
yaml_header += "\ntags:"
yaml_header += "\n- torchao-my-repo"
# Complete the YAML header
yaml_header += "\n---"
# Create the quantization info section
quant_info = f"""
# {model_name} (Quantized)
## Description
This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}).
It's quantized using the TorchAO library using the [torchao-my-repo](https://huggingface.co/spaces/pytorch/torchao-my-repo) space.
## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {group_size}
"""
# Combine everything
model_card = yaml_header + quant_info
# Append original README content if available
if original_readme and not original_readme.isspace():
model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme
return model_card
def quantize_model(
model_name, quantization_type, group_size=128, auth_token=None, username=None
):
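    """Load `model_name` with a transformers TorchAoConfig and return the quantized model.
    `group_size` only applies to the int4/int8 weight-only quantization types."""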
print(f"Quantizing model: {quantization_type}")
if (
quantization_type == "int4_weight_only"
or quantization_type == "int8_weight_only"
):
quantization_config = TorchAoConfig(quantization_type, group_size=group_size)
else:
quantization_config = TorchAoConfig(quantization_type)
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        token=auth_token.token if auth_token else None,
    )
return model
def save_model(
model,
model_name,
quantization_type,
group_size=128,
username=None,
auth_token=None,
quantized_model_name=None,
):
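    """Save the tokenizer, quantized model, and generated model card to a temporary
    directory, push everything to `<username>/<repo_name>` on the Hub, and return an
    HTML summary containing the repo link and the printed model architecture."""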
print("Saving quantized model")
with tempfile.TemporaryDirectory() as tmpdirname:
# Load and save the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, token=auth_token.token if auth_token else None
        )
        tokenizer.save_pretrained(tmpdirname)
        # Save the model (torchao tensor subclasses require safe_serialization=False)
        model.save_pretrained(tmpdirname, safe_serialization=False)
if quantized_model_name:
repo_name = f"{username}/{quantized_model_name}"
else:
if (
quantization_type == "int4_weight_only"
or quantization_type == "int8_weight_only"
) and (group_size is not None):
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
else:
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
model_card = create_model_card(model_name, quantization_type, group_size)
with open(os.path.join(tmpdirname, "README.md"), "w") as f:
f.write(model_card)
# Push to Hub
api = HfApi(token=auth_token.token)
api.create_repo(repo_name, exist_ok=True)
api.upload_folder(
folder_path=tmpdirname,
repo_id=repo_name,
repo_type="model",
)
import io
from contextlib import redirect_stdout
import html
# Capture the model architecture string
f = io.StringIO()
with redirect_stdout(f):
print(model)
model_architecture_str = f.getvalue()
# Escape HTML characters and format with line breaks
    model_architecture_str_html = html.escape(model_architecture_str).replace(
        "\n", "<br/>"
    )
# Format it for display in markdown with proper styling
model_architecture_info = f"""
Find your repo here: {repo_name}

def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quantization_type, group_size, quantized_model_name):
    # Require a signed-in user before touching the Hub
    if oauth_token is None or profile is None:
        return "Please sign in to your HuggingFace account to use the quantizer."
    # Group Size must be a number for int4_weight_only and int8_weight_only, or empty for other types
    if group_size and not str(group_size).strip().isdigit():
        return "Group Size must be a number for int4_weight_only and int8_weight_only, or empty for other quantization types."
    exists_message = check_model_exists(
        oauth_token, profile.username, quantization_type, group_size, model_name, quantized_model_name
    )
    if exists_message:
        return exists_message