import gradio as gr
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
import tempfile
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from packaging import version
import os

from torchao.quantization import (
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Float8WeightOnlyConfig,
)

MAP_QUANT_TYPE_TO_NAME = {
    "int4_weight_only": "int4wo",
    "int8_weight_only": "int8wo",
    "int8_dynamic_activation_int8_weight": "int8da8w",
    "autoquant": "autoquant",
}
MAP_QUANT_TYPE_TO_CONFIG = {
    "int4_weight_only": Int4WeightOnlyConfig,
    "int8_weight_only": Int8WeightOnlyConfig,
    "int8_dynamic_activation_int8_weight": Int8DynamicActivationInt8WeightConfig,
    "float8_weight_only": Float8WeightOnlyConfig,
}
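
# Illustrative sketches (comments only, not executed by this app) of how the quantization
# types above are used; "org/some-model" is a hypothetical placeholder model id.
#
# Via transformers, which is what this app does:
#     config = TorchAoConfig("int4_weight_only", group_size=128)
#     model = AutoModel.from_pretrained("org/some-model", quantization_config=config)
#
# Via torchao directly on an already-loaded model, using the config classes mapped above:
#     from torchao.quantization import quantize_
#     quantize_(model, Int8WeightOnlyConfig())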


def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    # ^ expect a gr.OAuthProfile object as input to get the user's profile
    # if the user is not logged in, profile will be None
    if profile is None:
        return "Hello!"
    return f"Hello {profile.name}!"


def check_model_exists(
    oauth_token: gr.OAuthToken | None,
    username,
    quantization_type,
    group_size,
    model_name,
    quantized_model_name,
):
    """Check if a model already exists in the user's Hugging Face repositories."""
    try:
        models = list_models(author=username, token=oauth_token.token)
        model_names = [model.id for model in models]
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            if (
                quantization_type == "int4_weight_only"
                or quantization_type == "int8_weight_only"
            ) and (group_size is not None):
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
            else:
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"

        if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
        else:
            return None  # Model does not exist
    except Exception as e:
        return f"Error checking model existence: {str(e)}"


def create_model_card(model_name, quantization_type, group_size):
    # Try to download the original README
    original_readme = ""
    original_yaml_header = ""
    try:
        # Download the README.md file from the original model
        model_path = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
        )
        readme_path = os.path.join(model_path, "README.md")
        if os.path.exists(readme_path):
            with open(readme_path, "r", encoding="utf-8") as f:
                content = f.read()
                if content.startswith("---"):
                    parts = content.split("---", 2)
                    if len(parts) >= 3:
                        original_yaml_header = parts[1]
                        original_readme = "---".join(parts[2:])
                    else:
                        original_readme = content
                else:
                    original_readme = content
    except Exception as e:
        print(f"Error reading original README: {str(e)}")
        original_readme = ""

    # Create new YAML header with base_model field
    yaml_header = f"""---
base_model:
- {model_name}"""

    # Add any original YAML fields except base_model
    if original_yaml_header:
        in_base_model_section = False
        found_tags = False

        for line in original_yaml_header.strip().split("\n"):
            # Skip if we're in a base_model section that continues on the next line
            if in_base_model_section:
                if (
                    line.strip().startswith("-")
                    or not line.strip()
                    or line.startswith(" ")
                ):
                    continue
                else:
                    in_base_model_section = False

            # Check for base_model field
            if line.strip().startswith("base_model:"):
                in_base_model_section = True
                # If base_model has an inline value (like "base_model: model_name")
                if ":" in line and len(line.split(":", 1)[1].strip()) > 0:
                    in_base_model_section = False
                continue

            # Check for tags field and add torchao-my-repo
            if line.strip().startswith("tags:"):
                found_tags = True
                yaml_header += f"\n{line}"
                yaml_header += "\n- torchao-my-repo"
                continue

            yaml_header += f"\n{line}"

        # If tags field wasn't found, add it
        if not found_tags:
            yaml_header += "\ntags:"
            yaml_header += "\n- torchao-my-repo"

    # Complete the YAML header
    yaml_header += "\n---"

    # Create the quantization info section
    quant_info = f"""
# {model_name} (Quantized)
## Description
This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}).

It was quantized with the TorchAO library using the [torchao-my-repo](https://huggingface.co/spaces/pytorch/torchao-my-repo) space.
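
## Usage

An illustrative loading snippet (assuming `torchao` is installed alongside a `transformers` version with TorchAO support; replace the placeholder repo id with the repository this card was pushed to):

```python
from transformers import AutoModel, AutoTokenizer

# "<username>/<quantized-repo>" is a placeholder for the pushed quantized repository.
model = AutoModel.from_pretrained(
    "<username>/<quantized-repo>", device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained("<username>/<quantized-repo>")
```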
## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {group_size}
"""

    # Combine everything
    model_card = yaml_header + quant_info

    # Append original README content if available
    if original_readme and not original_readme.isspace():
        model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme

    return model_card


def quantize_model(
    model_name, quantization_type, group_size=128, auth_token=None, username=None
):
    print(f"Quantizing model: {quantization_type}")
    # group_size only applies to the weight-only int4/int8 schemes
    if (
        quantization_type == "int4_weight_only"
        or quantization_type == "int8_weight_only"
    ):
        quantization_config = TorchAoConfig(quantization_type, group_size=group_size)
    else:
        quantization_config = TorchAoConfig(quantization_type)
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        token=auth_token.token,
    )

    return model


def save_model(
    model,
    model_name,
    quantization_type,
    group_size=128,
    username=None,
    auth_token=None,
    quantized_model_name=None,
):
    print("Saving quantized model")
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Load and save the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=auth_token.token)
        tokenizer.save_pretrained(tmpdirname)

        # Save the quantized model
        model.save_pretrained(tmpdirname, safe_serialization=False)

        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            if (
                quantization_type == "int4_weight_only"
                or quantization_type == "int8_weight_only"
            ) and (group_size is not None):
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}-gs{group_size}"
            else:
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type.lower()]}"

        model_card = create_model_card(model_name, quantization_type, group_size)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)

        # Push to Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True)
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )

        import io
        from contextlib import redirect_stdout
        import html

        # Capture the model architecture string
        f = io.StringIO()
        with redirect_stdout(f):
            print(model)
        model_architecture_str = f.getvalue()

        # Escape HTML characters and replace newlines with <br> so the architecture
        # renders as preformatted HTML in the output Markdown component
        model_architecture_str_html = html.escape(model_architecture_str).replace(
            "\n", "<br>"
        )

        # Format the architecture for display
        model_architecture_info = f"""
<details>
<summary>📋 Model Architecture</summary>
<pre style="white-space: pre-wrap;">{model_architecture_str_html}</pre>
</details>
"""

        repo_link = f'<a href="https://huggingface.co/{repo_name}" target="_blank">{repo_name}</a>'

        return (
            f"## 🎉 Quantization Completed\n\n"
            f"Your quantized model has been pushed to {repo_link}\n"
            f"{model_architecture_info}"
        )
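
# Illustrative example of the default naming scheme above: quantizing "org/some-model"
# with int8_weight_only and group_size=128 is pushed to
# "<username>/some-model-ao-int8wo-gs128".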


def quantize_and_save(
    profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    model_name,
    quantization_type,
    group_size,
    quantized_model_name,
):
    if oauth_token is None:
        return """
## ❌ Authentication Error

Please sign in to your HuggingFace account to use the quantizer.
"""
    if not profile:
        return """
## ❌ Authentication Error

Please sign in to your HuggingFace account to use the quantizer.
"""
    if not group_size.isdigit():
        if group_size != "":
            return """
## ❌ Group Size Error

Group Size must be a number for int4_weight_only and int8_weight_only, or left empty for the other quantization types.
"""

    if group_size and group_size.strip():
        group_size = int(group_size)
    else:
        group_size = None

    exists_message = check_model_exists(
        oauth_token,
        profile.username,
        quantization_type,
        group_size,
        model_name,
        quantized_model_name,
    )
    if exists_message:
        return f"""
## ⚠️ Model Already Exists

{exists_message}
"""

    # if quantization_type == "int4_weight_only":
    #     return "int4_weight_only not supported on cpu"

    try:
        quantized_model = quantize_model(
            model_name, quantization_type, group_size, oauth_token, profile.username
        )
        return save_model(
            quantized_model,
            model_name,
            quantization_type,
            group_size,
            profile.username,
            oauth_token,
            quantized_model_name,
        )
    except Exception as e:
        return str(e)


def get_model_size(model):
    """
    Calculate the size of a PyTorch model in gigabytes.

    Args:
        model: PyTorch model

    Returns:
        float: Size of the model in GB
    """
    # Get model state dict
    state_dict = model.state_dict()

    # Calculate total size in bytes
    total_size = 0
    for param in state_dict.values():
        # Add bytes for each parameter (number of elements * bytes per element)
        total_size += param.nelement() * param.element_size()

    # Convert bytes to gigabytes (1 GB = 1,073,741,824 bytes)
    size_gb = total_size / (1024**3)
    size_gb = round(size_gb, 2)
    return size_gb
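
# Worked example (illustrative): a model with 1e9 parameters stored in fp16 occupies about
# 1e9 * 2 bytes = 2e9 bytes, which get_model_size reports as 2e9 / 1024**3 ≈ 1.86 GB.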


# Enhanced CSS styling for the UI
css = """
/* Custom CSS for enhanced UI */
.gradio-container {overflow-y: auto;}

/* Fix alignment for radio buttons and dropdowns */
.gradio-radio, .gradio-dropdown {
    display: flex !important;
    align-items: center !important;
    margin: 10px 0 !important;
}

/* Consistent spacing and alignment */
.gradio-dropdown, .gradio-textbox, .gradio-radio {
    margin-bottom: 12px !important;
    width: 100% !important;
}

/* Quantize button styling with glow effect */
button[variant="primary"] {
    background: linear-gradient(135deg, #3B82F6, #10B981) !important;
    color: white !important;
    padding: 16px 32px !important;
    font-size: 1.1rem !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 12px !important;
    box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
    transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
    position: relative;
    overflow: hidden;
    animation: glow 1.5s ease-in-out infinite alternate;
}

button[variant="primary"]::before {
    content: "✨ ";
}

button[variant="primary"]:hover {
    transform: translateY(-5px) scale(1.05) !important;
    box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important;
}

@keyframes glow {
    from {
        box-shadow: 0 0 10px rgba(59, 130, 246, 0.5);
    }
    to {
        box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5);
    }
}

/* Login button styling */
#login-button {
    background: linear-gradient(135deg, #3B82F6, #10B981) !important;
    color: white !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 12px !important;
    box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
    transition: all 0.3s ease !important;
    max-width: 300px !important;
    margin: 0 auto !important;
}
"""

# Main app layout
with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
    gr.Markdown(
        """
        # 🤗 TorchAO Model Quantizer ✨

        Quantize your favorite Hugging Face models using TorchAO and save them to your profile!
        """
    )

    gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)

    m1 = gr.Markdown()
    demo.load(hello, inputs=None, outputs=m1)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_name = HuggingfaceHubSearch(
                    label="🔍 Hub Model ID",
                    placeholder="Search for model id on Huggingface",
                    search_type="model",
                )

            gr.Markdown("""### ⚙️ Quantization Settings""")
            with gr.Row():
                with gr.Column():
                    quantization_type = gr.Dropdown(
                        info="Select the quantization method",
                        choices=[
                            "int4_weight_only",
                            "int8_weight_only",
                            "int8_dynamic_activation_int8_weight",
                            "autoquant",
                        ],
                        value="int8_weight_only",
                        filterable=False,
                        show_label=False,
                    )
                    group_size = gr.Textbox(
                        info="Group Size (only for int4_weight_only and int8_weight_only)",
                        value="128",
                        interactive=True,
                        show_label=False,
                    )
                    quantized_model_name = gr.Textbox(
                        info="Custom name for your quantized model (optional)",
                        value="",
                        interactive=True,
                        show_label=False,
                    )

        with gr.Column():
            quantize_button = gr.Button(
                "🚀 Quantize and Push to Hub", variant="primary"
            )
            output_link = gr.Markdown(
                label="🔗 Quantized Model Info", container=True, min_height=200
            )

    # Information section
    with gr.Accordion("📚 About TorchAO Quantization", open=True):
        gr.Markdown(
            """
            ## 📝 Quantization Options

            ### Quantization Types
            - **int4_weight_only**: 4-bit weight-only quantization
            - **int8_weight_only**: 8-bit weight-only quantization
            - **int8_dynamic_activation_int8_weight**: 8-bit quantization for both weights and activations

            ### Group Size
            - Only applicable for int4_weight_only and int8_weight_only quantization
            - Default value is 128
            - Affects the granularity of quantization

            ## 🔍 How It Works
            1. Downloads the original model
            2. Applies TorchAO quantization with your selected settings
            3. Uploads the quantized model to your HuggingFace account

            ## 📊 Memory Benefits
            - int4_weight_only can reduce model size by up to 75%
            - int8_weight_only typically reduces size by about 50%
            """
        )

    # Click handler; the gr.OAuthProfile / gr.OAuthToken parameters of quantize_and_save
    # are injected automatically by Gradio based on their type annotations.
    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quantization_type, group_size, quantized_model_name],
        outputs=[output_link],
    )

# Launch the app
demo.launch(share=True)