torchao-my-repo / app.py
import gradio as gr
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
import tempfile
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from packaging import version
import os
from torchao.quantization import (
Int4WeightOnlyConfig,
Int8WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Float8WeightOnlyConfig,
)
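# Short name suffixes used when auto-generating the quantized repo id (e.g. "MyModel-ao-int8wo-gs128")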
MAP_QUANT_TYPE_TO_NAME = {
"int4_weight_only": "int4wo",
"int8_weight_only": "int8wo",
"int8_dynamic_activation_int8_weight": "int8da8w",
"autoquant": "autoquant",
}
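# Matching torchao config classes (not referenced directly below: TorchAoConfig is built from the string name)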
MAP_QUANT_TYPE_TO_CONFIG = {
"int4_weight_only": Int4WeightOnlyConfig,
"int8_weight_only": Int8WeightOnlyConfig,
"int8_dynamic_activation_int8_weight": Int8DynamicActivationInt8WeightConfig,
"float8_weight_only": Float8WeightOnlyConfig,
}
def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
# ^ expect a gr.OAuthProfile object as input to get the user's profile
# if the user is not logged in, profile will be None
if profile is None:
return "Hello !"
return f"Hello {profile.name} !"
def check_model_exists(
oauth_token: gr.OAuthToken | None,
username,
quantization_type,
group_size,
model_name,
quantized_model_name,
):
"""Check if a model exists in the user's Hugging Face repository."""
try:
models = list_models(author=username, token=oauth_token.token)
model_names = [model.id for model in models]
if quantized_model_name:
repo_name = f"{username}/{quantized_model_name}"
else:
if (
quantization_type == "int4_weight_only"
or quantization_type == "int8_weight_only"
) and (group_size is not None):
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
else:
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
if repo_name in model_names:
return f"Model '{repo_name}' already exists in your repository."
else:
return None # Model does not exist
except Exception as e:
return f"Error checking model existence: {str(e)}"
def create_model_card(model_name, quantization_type, group_size):
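    """Build a README/model card for the quantized repo, reusing the original model's README when available."""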
# Try to download the original README
original_readme = ""
original_yaml_header = ""
try:
# Download the README.md file from the original model
model_path = snapshot_download(
repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
)
readme_path = os.path.join(model_path, "README.md")
if os.path.exists(readme_path):
with open(readme_path, "r", encoding="utf-8") as f:
content = f.read()
if content.startswith("---"):
parts = content.split("---", 2)
if len(parts) >= 3:
original_yaml_header = parts[1]
original_readme = "---".join(parts[2:])
else:
original_readme = content
else:
original_readme = content
except Exception as e:
print(f"Error reading original README: {str(e)}")
original_readme = ""
# Create new YAML header with base_model field
yaml_header = f"""---
base_model:
- {model_name}"""
# Add any original YAML fields except base_model
if original_yaml_header:
in_base_model_section = False
found_tags = False
for line in original_yaml_header.strip().split("\n"):
# Skip if we're in a base_model section that continues to the next line
if in_base_model_section:
if (
line.strip().startswith("-")
or not line.strip()
or line.startswith(" ")
):
continue
else:
in_base_model_section = False
# Check for base_model field
if line.strip().startswith("base_model:"):
in_base_model_section = True
# If base_model has inline value (like "base_model: model_name")
if ":" in line and len(line.split(":", 1)[1].strip()) > 0:
in_base_model_section = False
continue
            # Check for tags field and add the torchao-my-repo tag
if line.strip().startswith("tags:"):
found_tags = True
yaml_header += f"\n{line}"
yaml_header += "\n- torchao-my-repo"
continue
yaml_header += f"\n{line}"
# If tags field wasn't found, add it
if not found_tags:
yaml_header += "\ntags:"
yaml_header += "\n- torchao-my-repo"
# Complete the YAML header
yaml_header += "\n---"
# Create the quantization info section
quant_info = f"""
# {model_name} (Quantized)
## Description
This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}).
It's quantized using the TorchAO library using the [torchao-my-repo](https://huggingface.co/spaces/pytorch/torchao-my-repo) space.
## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {group_size}
"""
# Combine everything
model_card = yaml_header + quant_info
# Append original README content if available
if original_readme and not original_readme.isspace():
        model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme
return model_card
def quantize_model(
model_name, quantization_type, group_size=128, auth_token=None, username=None
):
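    """Load `model_name` from the Hub and quantize it on CPU with the selected TorchAO method."""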
print(f"Quantizing model: {quantization_type}")
    # group_size only applies to the weight-only int4/int8 methods; fall back to the
    # torchao defaults when it is left empty
    if (
        quantization_type in ("int4_weight_only", "int8_weight_only")
        and group_size is not None
    ):
        quantization_config = TorchAoConfig(quantization_type, group_size=group_size)
    else:
        quantization_config = TorchAoConfig(quantization_type)
model = AutoModel.from_pretrained(
model_name,
torch_dtype="auto",
quantization_config=quantization_config,
device_map="cpu",
use_auth_token=auth_token.token,
)
return model
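# Reloading sketch (not executed here): a repo produced by this Space can typically be
# loaded back with a plain `from_pretrained`, since the TorchAO quantization config is
# stored in the saved config.json. "username/my-model-ao-int8wo" is a hypothetical repo id.
#
#   from transformers import AutoModel
#   reloaded = AutoModel.from_pretrained("username/my-model-ao-int8wo")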
def save_model(
model,
model_name,
quantization_type,
group_size=128,
username=None,
auth_token=None,
quantized_model_name=None,
):
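    """Save the quantized model, tokenizer, and a generated model card, then push them to the Hub.

    Returns an HTML snippet with the repo link and the quantized model's architecture.
    """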
print("Saving quantized model")
with tempfile.TemporaryDirectory() as tmpdirname:
# Load and save the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name, use_auth_token=auth_token.token
)
tokenizer.save_pretrained(tmpdirname, use_auth_token=auth_token.token)
# Save the model
model.save_pretrained(
tmpdirname, safe_serialization=False, use_auth_token=auth_token.token
)
if quantized_model_name:
repo_name = f"{username}/{quantized_model_name}"
else:
if (
quantization_type == "int4_weight_only"
or quantization_type == "int8_weight_only"
) and (group_size is not None):
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
else:
repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"
model_card = create_model_card(model_name, quantization_type, group_size)
with open(os.path.join(tmpdirname, "README.md"), "w") as f:
f.write(model_card)
# Push to Hub
api = HfApi(token=auth_token.token)
api.create_repo(repo_name, exist_ok=True)
api.upload_folder(
folder_path=tmpdirname,
repo_id=repo_name,
repo_type="model",
)
import io
from contextlib import redirect_stdout
import html
# Capture the model architecture string
f = io.StringIO()
with redirect_stdout(f):
print(model)
model_architecture_str = f.getvalue()
# Escape HTML characters and format with line breaks
model_architecture_str_html = html.escape(model_architecture_str).replace(
"\n", "<br/>"
)
# Format it for display in markdown with proper styling
model_architecture_info = f"""
<div class="model-architecture-container" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
<h3 style="margin-top: 0; color: #2E7D32;">πŸ“‹ Model Architecture</h3>
<div class="model-architecture" style="max-height: 500px; overflow-y: auto; overflow-x: auto; background-color: #f5f5f5; padding: 5px; border-radius: 8px; font-family: monospace; white-space: pre-wrap;">
<div style="line-height: 1.2; font-size: 0.75em;">{model_architecture_str_html}</div>
</div>
</div>
"""
repo_link = f"""
<div class="repo-link" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
<h3 style="margin-top: 0; color: #2E7D32;">πŸ”— Repository Link</h3>
<p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a></p>
</div>
"""
return (
f"<h1>πŸŽ‰ Quantization Completed</h1><br/>{repo_link}{model_architecture_info}"
)
def quantize_and_save(
profile: gr.OAuthProfile | None,
oauth_token: gr.OAuthToken | None,
model_name,
quantization_type,
group_size,
quantized_model_name,
):
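    """Gradio callback: validate the inputs, quantize the selected model, and push it to the user's profile."""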
    if oauth_token is None or profile is None:
        return """
        <div class="error-box">
            <h3>❌ Authentication Error</h3>
            <p>Please sign in to your HuggingFace account to use the quantizer.</p>
        </div>
        """
    # Group Size must be an integer for the weight-only methods; leave it empty to use
    # the torchao defaults (it is ignored for the other quantization types)
    group_size = group_size.strip() if group_size else ""
    if group_size and not group_size.isdigit():
        return """
        <div class="error-box">
            <h3>❌ Group Size Error</h3>
            <p>Group Size must be a number for int4_weight_only and int8_weight_only, or left empty for the other quantization types.</p>
        </div>
        """
    group_size = int(group_size) if group_size else None
exists_message = check_model_exists(
oauth_token,
profile.username,
quantization_type,
group_size,
model_name,
quantized_model_name,
)
if exists_message:
return f"""
<div class="warning-box">
<h3>⚠️ Model Already Exists</h3>
<p>{exists_message}</p>
</div>
"""
# if quantization_type == "int4_weight_only" :
# return "int4_weight_only not supported on cpu"
try:
quantized_model = quantize_model(
model_name, quantization_type, group_size, oauth_token, profile.username
)
return save_model(
quantized_model,
model_name,
quantization_type,
group_size,
profile.username,
oauth_token,
quantized_model_name,
)
except Exception as e:
return str(e)
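# Helper for reporting model sizes; not currently wired into the UI below.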
def get_model_size(model):
"""
Calculate the size of a PyTorch model in gigabytes.
Args:
model: PyTorch model
Returns:
float: Size of the model in GB
"""
# Get model state dict
state_dict = model.state_dict()
# Calculate total size in bytes
total_size = 0
for param in state_dict.values():
# Calculate bytes for each parameter
total_size += param.nelement() * param.element_size()
# Convert bytes to gigabytes (1 GB = 1,073,741,824 bytes)
size_gb = total_size / (1024**3)
size_gb = round(size_gb, 2)
return size_gb
# Add enhanced CSS styling
css = """
/* Custom CSS for enhanced UI */
.gradio-container {overflow-y: auto;}
/* Fix alignment for radio buttons and dropdowns */
.gradio-radio, .gradio-dropdown {
display: flex !important;
align-items: center !important;
margin: 10px 0 !important;
}
/* Consistent spacing and alignment */
.gradio-dropdown, .gradio-textbox, .gradio-radio {
margin-bottom: 12px !important;
width: 100% !important;
}
/* Quantize button styling with glow effect */
button[variant="primary"] {
background: linear-gradient(135deg, #3B82F6, #10B981) !important;
color: white !important;
padding: 16px 32px !important;
font-size: 1.1rem !important;
font-weight: 700 !important;
border: none !important;
border-radius: 12px !important;
box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
position: relative;
overflow: hidden;
animation: glow 1.5s ease-in-out infinite alternate;
}
button[variant="primary"]::before {
content: "✨ ";
}
button[variant="primary"]:hover {
transform: translateY(-5px) scale(1.05) !important;
box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important;
}
@keyframes glow {
from {
box-shadow: 0 0 10px rgba(59, 130, 246, 0.5);
}
to {
box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5);
}
}
/* Login button styling */
#login-button {
background: linear-gradient(135deg, #3B82F6, #10B981) !important;
color: white !important;
font-weight: 700 !important;
border: none !important;
border-radius: 12px !important;
box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
transition: all 0.3s ease !important;
max-width: 300px !important;
margin: 0 auto !important;
}
"""
# Update the main app layout
with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
    gr.Markdown(
        """
        # 🤗 TorchAO Model Quantizer ✨
        Quantize your favorite Hugging Face models using TorchAO and save them to your profile!
        <br/>
        """
    )
gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
m1 = gr.Markdown()
demo.load(hello, inputs=None, outputs=m1)
with gr.Row():
with gr.Column():
with gr.Row():
model_name = HuggingfaceHubSearch(
label="πŸ” Hub Model ID",
placeholder="Search for model id on Huggingface",
search_type="model",
)
gr.Markdown("""### βš™οΈ Quantization Settings""")
with gr.Row():
with gr.Column():
quantization_type = gr.Dropdown(
info="Select the Quantization method",
choices=[
"int4_weight_only",
"int8_weight_only",
"int8_dynamic_activation_int8_weight",
"autoquant",
],
value="int8_weight_only",
filterable=False,
show_label=False,
)
group_size = gr.Textbox(
info="Group Size (only for int4_weight_only and int8_weight_only)",
value="128",
interactive=True,
show_label=False,
)
quantized_model_name = gr.Textbox(
info="Custom name for your quantized model (optional)",
value="",
interactive=True,
show_label=False,
)
with gr.Column():
quantize_button = gr.Button(
"πŸš€ Quantize and Push to Hub", variant="primary"
)
output_link = gr.Markdown(
label="πŸ”— Quantized Model Info", container=True, min_height=200
)
# Add information section
    with gr.Accordion("📚 About TorchAO Quantization", open=True):
gr.Markdown(
"""
## πŸ“ Quantization Options
### Quantization Types
- **int4_weight_only**: 4-bit weight-only quantization
- **int8_weight_only**: 8-bit weight-only quantization
- **int8_dynamic_activation_int8_weight**: 8-bit quantization for both weights and activations
### Group Size
- Only applicable for int4_weight_only and int8_weight_only quantization
- Default value is 128
- Affects the granularity of quantization
## πŸ” How It Works
1. Downloads the original model
2. Applies TorchAO quantization with your selected settings
3. Uploads the quantized model to your HuggingFace account
## πŸ“Š Memory Benefits
- int4_weight_only can reduce model size by up to 75%
- int8_weight_only typically reduces size by about 50%
"""
)
    # The gr.OAuthProfile / gr.OAuthToken arguments of quantize_and_save are injected
    # automatically by Gradio, so only the visible components are passed as inputs
quantize_button.click(
fn=quantize_and_save,
inputs=[model_name, quantization_type, group_size, quantized_model_name],
outputs=[output_link],
)
# Launch the app
demo.launch(share=True)