# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
"""Gradio demo: sketch-to-image generation with the Sana ControlNet pipeline."""

import argparse
import os
import random
import socket
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import safety_check
from app.sana_controlnet_pipeline import SanaControlNetPipeline

# Prompt templates selectable in the UI; "{prompt}" marks where the user's text goes.
STYLES = {
    "None": "{prompt}",
    "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
    "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
    "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
    "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
    "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())

MAX_SEED = 1000000000
DEFAULT_SKETCH_GUIDANCE = 0.28
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))

# Number of channel values allowed to differ from pure white/black before a
# canvas stops counting as "blank" (1024 * 1024 * 3 - 100 == 3145628).
BLANK_PIXEL_TOLERANCE = 100
EMPTY_INPUT_MESSAGE = "Please input the prompt or draw something."

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))


def get_args():
    """Parse the demo's command-line options, ignoring unknown arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="config")
    parser.add_argument(
        "--model_path",
        nargs="?",
        default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
        type=str,
        help="Path to the model file (positional)",
    )
    parser.add_argument("--output", default="./", type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--image_size", default=1024, type=int)
    parser.add_argument("--cfg_scale", default=5.0, type=float)
    parser.add_argument("--pag_scale", default=2.0, type=float)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--step", default=-1, type=int)
    parser.add_argument("--custom_image_size", default=None, type=int)
    parser.add_argument("--share", action="store_true")
    parser.add_argument(
        "--shield_model_path",
        type=str,
        help="The path to shield model, we employ ShieldGemma-2B by default.",
        default="google/shieldgemma-2b",
    )
    # parse_known_args: extra flags passed by a launcher are tolerated.
    return parser.parse_known_args()[0]


args = get_args()

# Defined up-front so that `run` can detect a CPU-only host instead of
# failing with NameError; the models are only loaded when CUDA is present.
pipe = None
safety_checker_tokenizer = None
safety_checker_model = None

if torch.cuda.is_available():
    model_path = args.model_path
    pipe = SanaControlNetPipeline(args.config)
    pipe.from_pretrained(model_path)
    pipe.register_progress_bar(gr.Progress())

    # safety checker
    safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
    safety_checker_model = AutoModelForCausalLM.from_pretrained(
        args.shield_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).to(device)


def save_image(img):
    """Write *img* to a temporary PNG file and return the file's path.

    Args:
        img: A PIL image, or a sketchpad value dict whose ``"composite"``
            entry holds the image.

    Returns:
        Path of the saved ``.png`` file, or ``None`` when there is no image
        to save (e.g. the download button is clicked before any result).
    """
    if isinstance(img, dict):
        img = img["composite"]
    if img is None:
        return None
    # Reserve a unique name, then close the handle before PIL writes to the
    # path so no file descriptor is leaked (and the write works on Windows).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        path = temp_file.name
    img.save(path)
    return path


def norm_ip(img, low, high):
    """Clamp *img* to ``[low, high]`` and rescale it to ``[0, 1]``, in place.

    The tensor is modified in place and also returned for chaining.
    """
    img.clamp_(min=low, max=high)
    # The 1e-5 floor avoids division by zero when high == low.
    img.sub_(low).div_(max(high - low, 1e-5))
    return img


# inference_mode already disables gradient tracking, so no extra no_grad is needed.
@torch.inference_mode()
def run(
    image,
    prompt: str,
    prompt_template: str,
    sketch_thickness: int,
    guidance_scale: float,
    inference_steps: int,
    seed: int,
    blend_alpha: float,
) -> tuple[Image.Image, str]:
    """Generate an image from the sketch and prompt.

    Args:
        image: Sketchpad value; its ``"composite"`` entry is the PIL sketch.
        prompt: User prompt text.
        prompt_template: Style template; each literal ``{prompt}`` is replaced
            by the (safety-checked) prompt.
        sketch_thickness: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        inference_steps: Number of sampling steps.
        seed: Seed for the random generator.
        blend_alpha: Forwarded to ``pipe.set_blend_alpha``.

    Returns:
        ``(image, message)`` where *message* is the formatted latency, or a
        hint for the user when nothing was generated.
    """
    print(f"Prompt: {prompt}")

    if pipe is None:
        # Models are only loaded when CUDA is available.
        return blank_image, "This demo does not work on CPU."

    composite = image["composite"] if image else None
    if composite is None:
        return blank_image, EMPTY_INPUT_MESSAGE

    image_numpy = np.array(composite.convert("RGB"))
    # Scale the "blank canvas" threshold with the actual image size so that
    # uploads which are not 1024x1024 are classified correctly.
    blank_threshold = image_numpy.size - BLANK_PIXEL_TOLERANCE
    is_blank = (
        np.sum(image_numpy == 255) >= blank_threshold or np.sum(image_numpy == 0) >= blank_threshold
    )
    if prompt.strip() == "" and is_blank:
        return blank_image, EMPTY_INPUT_MESSAGE

    if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
        prompt = "A red heart."

    # The template comes from a user-editable textbox: plain substitution
    # avoids str.format errors on stray braces and attribute-access fields.
    prompt = prompt_template.replace("{prompt}", prompt)
    pipe.set_blend_alpha(blend_alpha)

    start_time = time.perf_counter()
    images = pipe(
        prompt=prompt,
        ref_image=composite,
        guidance_scale=guidance_scale,
        num_inference_steps=inference_steps,
        num_images_per_prompt=1,
        sketch_thickness=sketch_thickness,
        generator=torch.Generator(device=device).manual_seed(int(seed)),
    )
    latency = time.perf_counter() - start_time

    if latency < 1:
        latency_str = f"{latency * 1000:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    torch.cuda.empty_cache()

    # Only one image is requested per prompt, so convert just the first one.
    # NOTE(review): assumes each output is a (C, H, W) tensor in [-1, 1] — confirm.
    pixels = (
        norm_ip(images[0], -1, 1)
        .mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(pixels), latency_str


model_size = "1.6" if "1600M" in args.model_path else "0.6"
title = """
logo
"""
DESCRIPTION = f"""

Sana-ControlNet-{model_size}B{args.image_size}px

Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer

[Paper] [Github] [Project]

Powered by DC-AE with 32x latent space,

running on node {socket.gethostname()}.

Unsafe word will give you a 'Red Heart' in the image instead.

"""
if model_size == "0.6":
    DESCRIPTION += "\n\n\n0.6B model's text rendering ability is limited.\n\n"
if not torch.cuda.is_available():
    DESCRIPTION += "\n\n\nRunning on CPU 🥶 This demo does not work on CPU.\n\n"

# NOTE(review): the component nesting below was reconstructed from a source
# with no indentation — confirm the layout against the intended design.
with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title="Sana Sketch-to-Image Demo") as demo:
    gr.Markdown(title)
    gr.HTML(DESCRIPTION)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.Sketchpad(
                    value=blank_image,
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Sketch",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                    download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
                with gr.Row():
                    style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                    prompt_template = gr.Textbox(
                        label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                    )
                with gr.Row():
                    sketch_thickness = gr.Slider(
                        label="Sketch Thickness",
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=2,
                    )
                with gr.Row():
                    inference_steps = gr.Slider(
                        label="Sampling steps",
                        minimum=5,
                        maximum=40,
                        step=1,
                        value=20,
                    )
                    guidance_scale = gr.Slider(
                        label="CFG Guidance scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=4.5,
                    )
                    blend_alpha = gr.Slider(
                        label="Blend Alpha",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0,
                    )
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4
                    )
                    randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)

                download_result = gr.DownloadButton("Download Result", elem_id="download_result")
                gr.Markdown("### Instructions")
                gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
                gr.Markdown("**2**. Start sketching or upload a reference image")
                gr.Markdown("**3**. Change the image style using a style template")
                gr.Markdown("**4**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
    run_outputs = [result, latency_result]

    # Picking a new seed or style immediately re-runs generation.
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
    download_result.click(fn=save_image, inputs=result, outputs=download_result)
    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)