"""Gradio demo that generates images with an SDXL pipeline.

Configuration is read from environment variables:

* ``MODEL_ID`` - Hugging Face model id to load.
* ``MAX_IMAGE_SIZE`` - upper bound of the width/height sliders.
* ``USE_TORCH_COMPILE`` - ``"1"`` to compile the UNet with ``torch.compile``.
* ``ENABLE_CPU_OFFLOAD`` - ``"1"`` to enable model CPU offloading.
* ``BATCH_SIZE`` - number of images generated per pipeline call.
"""

import os
import random
import uuid
import json

import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

# Use environment variables for flexibility
MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Allow generating multiple images at once

# Determine device and load model outside of function for efficiency
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# CPU offloading (experimental) manages device placement itself, so the
# pipeline is either offloaded or moved to the target device, never both.
if ENABLE_CPU_OFFLOAD and torch.cuda.is_available():
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)

# Torch compile for potential speedup (experimental). Pipelines have no
# ``compile()`` method; the UNet is the module that gets compiled.
if USE_TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

MAX_SEED = np.iinfo(np.int32).max


def save_image(img):
    """Save *img* as a uniquely named PNG in the working directory.

    Returns the generated file name.
    """
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in ``[0, MAX_SEED]`` if requested, else *seed*."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@spaces.GPU(duration=35, enable_queue=True)
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 30,
    randomize_seed: bool = False,
    use_resolution_binning: bool = True,
    num_images: int = 1,  # Number of images to generate
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images for *prompt* and save them to disk.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt; only used when
            ``use_negative_prompt`` is true.
        use_negative_prompt: Whether to pass ``negative_prompt`` to the pipeline.
        seed: Seed for the random generator (ignored if ``randomize_seed``).
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        randomize_seed: Replace ``seed`` with a random value.
        use_resolution_binning: Accepted for interface compatibility only; it
            is not forwarded because the SDXL pipeline call does not define
            such an option.
        num_images: Number of images to generate (at least one).
        progress: Gradio progress tracker (mirrors tqdm progress).

    Returns:
        A tuple ``(image_paths, seed)`` with the saved file names and the seed
        that was actually used.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    # UI sliders may deliver floats; the values below must be integers.
    num_images = max(1, int(num_images))
    batch_size = max(1, BATCH_SIZE)  # guard against BATCH_SIZE <= 0

    # Options shared by every batch; prompts are added per batch below.
    shared_options = {
        "width": int(width),
        "height": int(height),
        "guidance_scale": guidance_scale,
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
        "output_type": "pil",
    }

    # Generate images potentially in batches
    images = []
    for start in range(0, num_images, batch_size):
        count = min(batch_size, num_images - start)
        batch_options = dict(shared_options)
        batch_options["prompt"] = [prompt] * count
        # Without a negative prompt the pipeline must receive None, not a list.
        batch_options["negative_prompt"] = (
            [negative_prompt] * count if use_negative_prompt else None
        )
        images.extend(pipe(**batch_options).images)

    image_paths = [save_image(img) for img in images]
    return image_paths, seed


examples = [
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
    "Ironman VS Hulk, ultrarealistic",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
    "Kids going to school, Anime style",
]

# Page layout tweaks plus the "hamster in a wheel" loader animation styles.
css = '''
.gradio-container{max-width: 700px !important}
h1{text-align:center}
footer { visibility: hidden }
.wheel-and-hamster { --dur: 1s; position: relative; width: 12em; height: 12em; font-size: 14px; }
.wheel, .hamster, .hamster div, .spoke { position: absolute; }
.wheel, .spoke { border-radius: 50%; top: 0; left: 0; width: 100%; height: 100%; }
.wheel { background: radial-gradient(100% 100% at center,hsla(0,0%,60%,0) 47.8%,hsl(0,0%,60%) 48%); z-index: 2; }
.hamster { animation: hamster var(--dur) ease-in-out infinite; top: 50%; left: calc(50% - 3.5em); width: 7em; height: 3.75em; transform: rotate(4deg) translate(-0.8em,1.85em); transform-origin: 50% 0; z-index: 1; }
.hamster__head { animation: hamsterHead var(--dur) ease-in-out infinite; background: hsl(30,90%,55%); border-radius: 70% 30% 0 100% / 40% 25% 25% 60%; box-shadow: 0 -0.25em 0 hsl(30,90%,80%) inset, 0.75em -1.55em 0 hsl(30,90%,90%) inset; top: 0; left: -2em; width: 2.75em; height: 2.5em; transform-origin: 100% 50%; }
.hamster__ear { animation: hamsterEar var(--dur) ease-in-out infinite; background: hsl(0,90%,85%); border-radius: 50%; box-shadow: -0.25em 0 hsl(30,90%,55%) inset; top: -0.25em; right: -0.25em; width: 0.75em; height: 0.75em; transform-origin: 50% 75%; }
.hamster__eye { animation: hamsterEye var(--dur) linear infinite; background-color: hsl(0,0%,0%); border-radius: 50%; top: 0.375em; left: 1.25em; width: 0.5em; height: 0.5em; }
.hamster__nose { background: hsl(0,90%,75%); border-radius: 35% 65% 85% 15% / 70% 50% 50% 30%; top: 0.75em; left: 0; width: 0.2em; height: 0.25em; }
.hamster__body { animation: hamsterBody var(--dur) ease-in-out infinite; background: hsl(30,90%,90%); border-radius: 50% 30% 50% 30% / 15% 60% 40% 40%; box-shadow: 0.1em 0.75em 0 hsl(30,90%,55%) inset, 0.15em -0.5em 0 hsl(30,90%,80%) inset; top: 0.25em; left: 2em; width: 4.5em; height: 3em; transform-origin: 17% 50%; transform-style: preserve-3d; }
.hamster__limb--fr, .hamster__limb--fl { clip-path: polygon(0 0,100% 0,70% 80%,60% 100%,0% 100%,40% 80%); top: 2em; left: 0.5em; width: 1em; height: 1.5em; transform-origin: 50% 0; }
.hamster__limb--fr { animation: hamsterFRLimb var(--dur) linear infinite; background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%); transform: rotate(15deg) translateZ(-1px); }
.hamster__limb--fl { animation: hamsterFLLimb var(--dur) linear infinite; background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%); transform: rotate(-60deg) translateZ(-1px); }
.hamster__limb--br, .hamster__limb--bl { clip-path: polygon(0 0,100% 0,100% 20%,30% 100%,0% 100%); top: 1.25em; left: 2.8em; width: 1.5em; height: 2.5em; transform-origin: 33% 10%; }
.hamster__limb--br { animation: hamsterBRLimb var(--dur) linear infinite; background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%); transform: rotate(-15deg) translateZ(-1px); }
.hamster__limb--bl { animation: hamsterBLLimb var(--dur) linear infinite; background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%); transform: rotate(60deg) translateZ(-1px); }
.hamster__tail { animation: hamsterTail var(--dur) linear infinite; background: hsl(0,90%,85%); border-radius: 0.25em 50% 50% 0.25em; box-shadow: 0.25em 0 hsl(30,90%,55%) inset; top: 1.5em; left: 5.5em; width: 0.5em; height: 0.75em; transform: rotate(30deg) translateZ(-1px); transform-origin: 0.25em 0.125em; }
.spoke { background: radial-gradient(hsl(0,0%,70%) 25%,hsla(0,0%,60%,0) 26%) center/8px 8px; z-index: 0; }
.spoke--1 { animation: spoke var(--dur) linear infinite; }
.spoke--2 { animation: spoke var(--dur) linear infinite; transform: rotate(30deg); }
.spoke--3 { animation: spoke var(--dur) linear infinite; transform: rotate(60deg); }
@keyframes hamster { 0%,100% { transform: rotate(4deg) translate(-0.8em,1.85em) } 50% { transform: rotate(0) translate(-0.8em,1.85em) } }
@keyframes hamsterHead { 0%,100% { transform: rotate(0) } 50% { transform: rotate(-8deg) } }
@keyframes hamsterEar { 0%,100% { transform: rotate(0) } 50% { transform: rotate(-3deg) } }
@keyframes hamsterEye { 0%,90%,100% { transform: scaleY(1) } 95% { transform: scaleY(0) } }
@keyframes hamsterBody { 0%,100% { transform: rotate(0) } 50% { transform: rotate(2deg) } }
@keyframes hamsterFRLimb { 0%,100% { transform: rotate(15deg) translateZ(-1px) } 50% { transform: rotate(-30deg) translateZ(-1px) } }
@keyframes hamsterFLLimb { 0%,100% { transform: rotate(-60deg) translateZ(-1px) } 50% { transform: rotate(-25deg) translateZ(-1px) } }
@keyframes hamsterBRLimb { 0%,100% { transform: rotate(-15deg) translateZ(-1px) } 50% { transform: rotate(30deg) translateZ(-1px) } }
@keyframes hamsterBLLimb { 0%,100% { transform: rotate(60deg) translateZ(-1px) } 50% { transform: rotate(25deg) translateZ(-1px) } }
@keyframes hamsterTail { 0%,100% { transform: rotate(30deg) translateZ(-1px) } 50% { transform: rotate(10deg) translateZ(-1px) } }
@keyframes spoke { 0% { transform: rotate(0) } 100% { transform: rotate(1turn) } }
'''

# NOTE(review): this markup is empty in the file as received, although the
# stylesheet above defines loader classes - confirm whether HTML was lost.
html = '''
'''

with gr.Blocks(css=css) as demo:
    gr.HTML(html)
    gr.Markdown("# Flash Attention with SDXL")
    gr.Markdown("Generate images with Flash Attention and SDXL")
    with gr.Row():
        with gr.Column(scale=55):
            # ``.style(...)`` was removed in Gradio 4; layout options are
            # constructor arguments instead.
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter negative prompt",
                container=False,
            )
        with gr.Column(scale=45):
            generate_btn = gr.Button("Generate")
    with gr.Row():
        image_output = gr.Gallery(label="Generated Images", columns=2, height="auto")
        seed_output = gr.Number(label="Seed Used")
    gr.Examples(examples=examples, inputs=[prompt])

    # Order must match the positional parameters of ``generate``. Integer
    # valued sliders get an explicit step so the UI only produces whole
    # numbers (and image sizes that are multiples of 8).
    inputs = [
        prompt,
        negative_prompt,
        gr.Checkbox(False, label="Use Negative Prompt"),
        gr.Slider(1, MAX_SEED, value=1, step=1, label="Seed"),
        gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=8, label="Width"),
        gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=8, label="Height"),
        gr.Slider(1, 20, value=7.5, label="Guidance Scale"),
        gr.Slider(1, 100, value=30, step=1, label="Number of Inference Steps"),
        gr.Checkbox(False, label="Randomize Seed"),
        gr.Checkbox(True, label="Use Resolution Binning"),
        gr.Slider(1, 10, value=1, step=1, label="Number of Images"),
    ]
    generate_btn.click(fn=generate, inputs=inputs, outputs=[image_output, seed_output])

if __name__ == "__main__":
    demo.launch()