# Hugging Face Space: runs on ZeroGPU ("Zero") hardware.
import os | |
import random | |
import uuid | |
import json | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import spaces | |
import torch | |
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler | |
# Runtime configuration, read from environment variables so the Space can be
# tuned without code changes.
MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
# Images per pipeline call in `generate`. Clamped to >= 1 because it is used as
# the step of a `range()`, which raises ValueError for a step of 0.
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "1")))

# Load the model once at import time so every request reuses it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    # Half precision only when CUDA is available; CPU inference uses float32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Optional torch.compile speed-up (experimental). Diffusers pipelines have no
# `.compile()` method; the UNet is the module that gets compiled.
# NOTE(review): combining this with ENABLE_CPU_OFFLOAD is untested.
if USE_TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

if ENABLE_CPU_OFFLOAD:
    # CPU offloading (experimental) manages device placement itself, so the
    # pipeline is not moved to the GPU up front in this mode.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)

# Upper bound for seeds: used by the UI seed slider and by random seed choice.
MAX_SEED = np.iinfo(np.int32).max
def save_image(img):
    """Save *img* under a fresh random ``<uuid4>.png`` name in the working directory.

    Args:
        img: Any object with a ``save(path)`` method (here, a PIL image).

    Returns:
        The relative file name that was written.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for a generation request.

    Returns *seed* unchanged, or, when *randomize_seed* is true, a new random
    integer drawn uniformly from ``[0, MAX_SEED]`` (both ends inclusive).
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU  # ZeroGPU only grants a GPU while a @spaces.GPU function runs.
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 30,
    randomize_seed: bool = False,
    use_resolution_binning: bool = True,
    num_images: int = 1,  # Number of images to generate
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images for *prompt* and save them as PNG files.

    Args:
        prompt: Text prompt, repeated once per requested image.
        negative_prompt: Negative prompt; only used when *use_negative_prompt*.
        use_negative_prompt: Whether to pass *negative_prompt* to the pipeline.
        seed: Seed for the torch generator (ignored if *randomize_seed*).
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        randomize_seed: Replace *seed* with a random one.
        use_resolution_binning: Forwarded to the pipeline when true.
        num_images: Total number of images; generated BATCH_SIZE at a time.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        ``(image_paths, seed)``: the saved PNG file names and the seed used.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    # A single generator serves all batches: each pipeline call advances its
    # state, so batches differ while the request stays reproducible from `seed`.
    generator = torch.Generator(device=device).manual_seed(seed)

    # UI sliders may deliver floats; list repetition, range() and latent sizes
    # all require ints.
    num_images = int(num_images)
    prompts = [prompt] * num_images
    negative_prompts = [negative_prompt] * num_images if use_negative_prompt else None

    options = {
        "width": int(width),
        "height": int(height),
        "guidance_scale": guidance_scale,
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
        "output_type": "pil",
    }
    # NOTE(review): confirm the loaded pipeline's __call__ actually honours
    # `use_resolution_binning`; it may only be swallowed by **kwargs.
    if use_resolution_binning:
        options["use_resolution_binning"] = True

    images = []
    for start in range(0, num_images, BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch_options = dict(options, prompt=prompts[start:stop])
        # Only slice negative prompts when they exist; slicing None raises TypeError.
        if negative_prompts is not None:
            batch_options["negative_prompt"] = negative_prompts[start:stop]
        images.extend(pipe(**batch_options).images)

    image_paths = [save_image(img) for img in images]
    return image_paths, seed
# Example prompts offered through gr.Examples for the prompt textbox.
examples = [
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
    "Ironman VS Hulk, ultrarealistic",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
    "Kids going to school, Anime style",
]
# Page CSS passed to gr.Blocks: narrows the container, centres the <h1>, hides
# the footer, and styles the `.wheel-and-hamster` animation used in `html`.
css = '''
.gradio-container{max-width: 700px !important}
h1{text-align:center}
footer {
  visibility: hidden
}
.wheel-and-hamster {
  --dur: 1s;
  position: relative;
  width: 12em;
  height: 12em;
  font-size: 14px;
}
.wheel,
.hamster,
.hamster div,
.spoke {
  position: absolute;
}
.wheel,
.spoke {
  border-radius: 50%;
  top: 0;
  left: 0;
  width: 100%;
  height: 100%;
}
.wheel {
  background: radial-gradient(100% 100% at center,hsla(0,0%,60%,0) 47.8%,hsl(0,0%,60%) 48%);
  z-index: 2;
}
.hamster {
  animation: hamster var(--dur) ease-in-out infinite;
  top: 50%;
  left: calc(50% - 3.5em);
  width: 7em;
  height: 3.75em;
  transform: rotate(4deg) translate(-0.8em,1.85em);
  transform-origin: 50% 0;
  z-index: 1;
}
.hamster__head {
  animation: hamsterHead var(--dur) ease-in-out infinite;
  background: hsl(30,90%,55%);
  border-radius: 70% 30% 0 100% / 40% 25% 25% 60%;
  box-shadow: 0 -0.25em 0 hsl(30,90%,80%) inset,
    0.75em -1.55em 0 hsl(30,90%,90%) inset;
  top: 0;
  left: -2em;
  width: 2.75em;
  height: 2.5em;
  transform-origin: 100% 50%;
}
.hamster__ear {
  animation: hamsterEar var(--dur) ease-in-out infinite;
  background: hsl(0,90%,85%);
  border-radius: 50%;
  box-shadow: -0.25em 0 hsl(30,90%,55%) inset;
  top: -0.25em;
  right: -0.25em;
  width: 0.75em;
  height: 0.75em;
  transform-origin: 50% 75%;
}
.hamster__eye {
  animation: hamsterEye var(--dur) linear infinite;
  background-color: hsl(0,0%,0%);
  border-radius: 50%;
  top: 0.375em;
  left: 1.25em;
  width: 0.5em;
  height: 0.5em;
}
.hamster__nose {
  background: hsl(0,90%,75%);
  border-radius: 35% 65% 85% 15% / 70% 50% 50% 30%;
  top: 0.75em;
  left: 0;
  width: 0.2em;
  height: 0.25em;
}
.hamster__body {
  animation: hamsterBody var(--dur) ease-in-out infinite;
  background: hsl(30,90%,90%);
  border-radius: 50% 30% 50% 30% / 15% 60% 40% 40%;
  box-shadow: 0.1em 0.75em 0 hsl(30,90%,55%) inset,
    0.15em -0.5em 0 hsl(30,90%,80%) inset;
  top: 0.25em;
  left: 2em;
  width: 4.5em;
  height: 3em;
  transform-origin: 17% 50%;
  transform-style: preserve-3d;
}
.hamster__limb--fr,
.hamster__limb--fl {
  clip-path: polygon(0 0,100% 0,70% 80%,60% 100%,0% 100%,40% 80%);
  top: 2em;
  left: 0.5em;
  width: 1em;
  height: 1.5em;
  transform-origin: 50% 0;
}
.hamster__limb--fr {
  animation: hamsterFRLimb var(--dur) linear infinite;
  background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
  transform: rotate(15deg) translateZ(-1px);
}
.hamster__limb--fl {
  animation: hamsterFLLimb var(--dur) linear infinite;
  background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
  transform: rotate(-60deg) translateZ(-1px);
}
.hamster__limb--br,
.hamster__limb--bl {
  clip-path: polygon(0 0,100% 0,100% 20%,30% 100%,0% 100%);
  top: 1.25em;
  left: 2.8em;
  width: 1.5em;
  height: 2.5em;
  transform-origin: 33% 10%;
}
.hamster__limb--br {
  animation: hamsterBRLimb var(--dur) linear infinite;
  background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
  transform: rotate(-15deg) translateZ(-1px);
}
.hamster__limb--bl {
  animation: hamsterBLLimb var(--dur) linear infinite;
  background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
  transform: rotate(60deg) translateZ(-1px);
}
.hamster__tail {
  animation: hamsterTail var(--dur) linear infinite;
  background: hsl(0,90%,85%);
  border-radius: 0.25em 50% 50% 0.25em;
  box-shadow: 0.25em 0 hsl(30,90%,55%) inset;
  top: 1.5em;
  left: 5.5em;
  width: 0.5em;
  height: 0.75em;
  transform: rotate(30deg) translateZ(-1px);
  transform-origin: 0.25em 0.125em;
}
.spoke {
  background: radial-gradient(hsl(0,0%,70%) 25%,hsla(0,0%,60%,0) 26%) center/8px 8px;
  z-index: 0;
}
.spoke--1 {
  animation: spoke var(--dur) linear infinite;
}
.spoke--2 {
  animation: spoke var(--dur) linear infinite;
  transform: rotate(30deg);
}
.spoke--3 {
  animation: spoke var(--dur) linear infinite;
  transform: rotate(60deg);
}
@keyframes hamster {
  0%,100% { transform: rotate(4deg) translate(-0.8em,1.85em) }
  50% { transform: rotate(0) translate(-0.8em,1.85em) }
}
@keyframes hamsterHead {
  0%,100% { transform: rotate(0) }
  50% { transform: rotate(-8deg) }
}
@keyframes hamsterEar {
  0%,100% { transform: rotate(0) }
  50% { transform: rotate(-3deg) }
}
@keyframes hamsterEye {
  0%,90%,100% { transform: scaleY(1) }
  95% { transform: scaleY(0) }
}
@keyframes hamsterBody {
  0%,100% { transform: rotate(0) }
  50% { transform: rotate(2deg) }
}
@keyframes hamsterFRLimb {
  0%,100% { transform: rotate(15deg) translateZ(-1px) }
  50% { transform: rotate(-30deg) translateZ(-1px) }
}
@keyframes hamsterFLLimb {
  0%,100% { transform: rotate(-60deg) translateZ(-1px) }
  50% { transform: rotate(-25deg) translateZ(-1px) }
}
@keyframes hamsterBRLimb {
  0%,100% { transform: rotate(-15deg) translateZ(-1px) }
  50% { transform: rotate(30deg) translateZ(-1px) }
}
@keyframes hamsterBLLimb {
  0%,100% { transform: rotate(60deg) translateZ(-1px) }
  50% { transform: rotate(25deg) translateZ(-1px) }
}
@keyframes hamsterTail {
  0%,100% { transform: rotate(30deg) translateZ(-1px) }
  50% { transform: rotate(10deg) translateZ(-1px) }
}
@keyframes spoke {
  0% { transform: rotate(0) }
  100% { transform: rotate(1turn) }
}
'''
# Decorative hamster-wheel animation, styled by the `.wheel-and-hamster` rules
# in the page CSS and rendered through gr.HTML.
# Markup shown by gr.HTML is injected dynamically, and browsers do not execute
# <script> tags injected that way. `window.onload` has also already fired by
# then. The old "hide on load" script therefore never ran, and the 100vh
# container stayed on screen, pushing the app below the fold. The container is
# now sized to its content and the dead script is gone.
html = '''
<div id="loading-animation" style="display: flex; justify-content: center; align-items: center;">
  <div class="wheel-and-hamster">
    <div class="wheel"></div>
    <div class="hamster">
      <div class="hamster__body">
        <div class="hamster__head">
          <div class="hamster__ear"></div>
          <div class="hamster__eye"></div>
          <div class="hamster__nose"></div>
        </div>
        <div class="hamster__limb hamster__limb--fr"></div>
        <div class="hamster__limb hamster__limb--fl"></div>
        <div class="hamster__limb hamster__limb--br"></div>
        <div class="hamster__limb hamster__limb--bl"></div>
        <div class="hamster__tail"></div>
      </div>
    </div>
    <div class="spoke spoke--1"></div>
    <div class="spoke spoke--2"></div>
    <div class="spoke spoke--3"></div>
  </div>
</div>
'''
# Build the Gradio UI and wire the Generate button to `generate`.
with gr.Blocks(css=css) as demo:
    gr.HTML(html)
    # NOTE(review): nothing in this file enables Flash Attention; the model is
    # SDXL-Flash (see MODEL_ID), so this title may be misleading.
    gr.Markdown("# Flash Attention with SDXL")
    gr.Markdown("Generate images with Flash Attention and SDXL")
    with gr.Row():
        with gr.Column(scale=55):
            # `.style()` was removed in Gradio 4; layout options are now
            # constructor arguments.
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter negative prompt",
                container=False,
            )
        with gr.Column(scale=45):
            generate_btn = gr.Button("Generate")
    with gr.Row():
        image_output = gr.Gallery(label="Generated Images", columns=2, height="auto")
        seed_output = gr.Number(label="Seed Used")
    gr.Examples(examples=examples, inputs=[prompt])

    # Order must match generate()'s positional parameters.
    inputs = [
        prompt,
        negative_prompt,
        gr.Checkbox(False, label="Use Negative Prompt"),
        gr.Slider(1, MAX_SEED, value=1, step=1, label="Seed"),
        # SDXL rejects sizes not divisible by 8, so only offer multiples of 8.
        gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=8, label="Width"),
        gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=8, label="Height"),
        # Fractional step so the 7.5 default lies on the slider grid.
        gr.Slider(1, 20, value=7.5, step=0.1, label="Guidance Scale"),
        gr.Slider(1, 100, value=30, step=1, label="Number of Inference Steps"),
        gr.Checkbox(False, label="Randomize Seed"),
        gr.Checkbox(True, label="Use Resolution Binning"),
        gr.Slider(1, 10, value=1, step=1, label="Number of Images"),
    ]
    generate_btn.click(fn=generate, inputs=inputs, outputs=[image_output, seed_output])

demo.launch()