STABLE-HAMSTER / app.py
prithivMLmods's picture
Update app.py
45c21ee verified
raw
history blame
11.3 kB
import os
import random
import uuid
import json
import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
# Runtime configuration, overridable through environment variables.
MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
# Maximum number of prompts sent to the pipeline per call when several images
# are requested. Clamped to >= 1 because generate() uses it as a range() step,
# and a step of 0 raises ValueError.
BATCH_SIZE = max(1, int(os.getenv("BATCH_SIZE", "1")))

# Load the model once at import time so every request reuses the same pipeline.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

if ENABLE_CPU_OFFLOAD:
    # Model CPU offload (experimental) lowers GPU memory use by managing device
    # placement itself. Diffusers advises against moving the pipeline to the GPU
    # beforehand, because that loads every sub-model onto the GPU first.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)

# torch.compile for a potential speed-up (experimental). DiffusionPipeline has
# no compile() method, so compile the UNet, which runs at every denoising step.
# NOTE(review): combining this with ENABLE_CPU_OFFLOAD is untested.
if USE_TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

# Upper bound for seeds: the largest 32-bit signed integer.
MAX_SEED = np.iinfo(np.int32).max
def save_image(img):
    """Write ``img`` to the current directory under a random UUID-based name.

    The object only needs a ``save(path)`` method. Returns the filename
    (``<uuid4>.png``) that was written.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Choose the seed for a generation run.

    Returns ``seed`` unchanged unless ``randomize_seed`` is true. In that case
    it returns a fresh value drawn uniformly from ``[0, MAX_SEED]``.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU(duration=35, enable_queue=True)
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 30,
    randomize_seed: bool = False,
    use_resolution_binning: bool = True,
    num_images: int = 1,  # Number of images to generate
    progress=gr.Progress(track_tqdm=True),
):
    """Generate ``num_images`` images for ``prompt`` and save them as PNG files.

    Prompts go to the pipeline in chunks of at most ``BATCH_SIZE`` per call.
    Every chunk draws from one ``torch.Generator`` seeded with ``seed``, or with
    a fresh random seed when ``randomize_seed`` is true.

    The negative prompt is used only when ``use_negative_prompt`` is true.
    Otherwise ``None`` is passed to the pipeline.

    Returns:
        ``(image_paths, seed)``: the files written by ``save_image`` and the
        seed actually used, so the UI can show it.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    # Gradio sliders may deliver floats. List repetition, range() and the
    # pipeline's size/step arguments all need ints.
    num_images = int(num_images)
    prompts = [prompt] * num_images
    negative_prompts = [negative_prompt] * num_images if use_negative_prompt else None

    options = {
        "width": int(width),
        "height": int(height),
        "guidance_scale": guidance_scale,
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
        "output_type": "pil",
    }
    # NOTE(review): confirm the installed diffusers' StableDiffusionXLPipeline
    # accepts (or harmlessly ignores) a `use_resolution_binning` argument.
    if use_resolution_binning:
        options["use_resolution_binning"] = True

    images = []
    for start in range(0, num_images, BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch_options = dict(options, prompt=prompts[start:stop])
        # Slice the negative prompts only when they exist. Slicing None would
        # raise TypeError.
        batch_options["negative_prompt"] = (
            negative_prompts[start:stop] if negative_prompts is not None else None
        )
        images.extend(pipe(**batch_options).images)

    image_paths = [save_image(img) for img in images]
    return image_paths, seed
# Example prompts passed to gr.Examples below, whose only input target is the
# Prompt textbox.
examples = [
"a cat eating a piece of cheese",
"a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
"Ironman VS Hulk, ultrarealistic",
"Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
"An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
"Kids going to school, Anime style"
]
# Stylesheet passed to gr.Blocks(css=...). It does three things:
# - caps the app container at 700px wide and centers <h1> headings;
# - hides the Gradio footer;
# - defines the hamster-in-a-wheel animation (shapes plus @keyframes) used by
#   the markup in `html`.
# Every animation shares one cycle length through the --dur custom property (1s).
css = '''
.gradio-container{max-width: 700px !important}
h1{text-align:center}
footer {
visibility: hidden
}
.wheel-and-hamster {
--dur: 1s;
position: relative;
width: 12em;
height: 12em;
font-size: 14px;
}
.wheel,
.hamster,
.hamster div,
.spoke {
position: absolute;
}
.wheel,
.spoke {
border-radius: 50%;
top: 0;
left: 0;
width: 100%;
height: 100%;
}
.wheel {
background: radial-gradient(100% 100% at center,hsla(0,0%,60%,0) 47.8%,hsl(0,0%,60%) 48%);
z-index: 2;
}
.hamster {
animation: hamster var(--dur) ease-in-out infinite;
top: 50%;
left: calc(50% - 3.5em);
width: 7em;
height: 3.75em;
transform: rotate(4deg) translate(-0.8em,1.85em);
transform-origin: 50% 0;
z-index: 1;
}
.hamster__head {
animation: hamsterHead var(--dur) ease-in-out infinite;
background: hsl(30,90%,55%);
border-radius: 70% 30% 0 100% / 40% 25% 25% 60%;
box-shadow: 0 -0.25em 0 hsl(30,90%,80%) inset,
0.75em -1.55em 0 hsl(30,90%,90%) inset;
top: 0;
left: -2em;
width: 2.75em;
height: 2.5em;
transform-origin: 100% 50%;
}
.hamster__ear {
animation: hamsterEar var(--dur) ease-in-out infinite;
background: hsl(0,90%,85%);
border-radius: 50%;
box-shadow: -0.25em 0 hsl(30,90%,55%) inset;
top: -0.25em;
right: -0.25em;
width: 0.75em;
height: 0.75em;
transform-origin: 50% 75%;
}
.hamster__eye {
animation: hamsterEye var(--dur) linear infinite;
background-color: hsl(0,0%,0%);
border-radius: 50%;
top: 0.375em;
left: 1.25em;
width: 0.5em;
height: 0.5em;
}
.hamster__nose {
background: hsl(0,90%,75%);
border-radius: 35% 65% 85% 15% / 70% 50% 50% 30%;
top: 0.75em;
left: 0;
width: 0.2em;
height: 0.25em;
}
.hamster__body {
animation: hamsterBody var(--dur) ease-in-out infinite;
background: hsl(30,90%,90%);
border-radius: 50% 30% 50% 30% / 15% 60% 40% 40%;
box-shadow: 0.1em 0.75em 0 hsl(30,90%,55%) inset,
0.15em -0.5em 0 hsl(30,90%,80%) inset;
top: 0.25em;
left: 2em;
width: 4.5em;
height: 3em;
transform-origin: 17% 50%;
transform-style: preserve-3d;
}
.hamster__limb--fr,
.hamster__limb--fl {
clip-path: polygon(0 0,100% 0,70% 80%,60% 100%,0% 100%,40% 80%);
top: 2em;
left: 0.5em;
width: 1em;
height: 1.5em;
transform-origin: 50% 0;
}
.hamster__limb--fr {
animation: hamsterFRLimb var(--dur) linear infinite;
background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
transform: rotate(15deg) translateZ(-1px);
}
.hamster__limb--fl {
animation: hamsterFLLimb var(--dur) linear infinite;
background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
transform: rotate(-60deg) translateZ(-1px);
}
.hamster__limb--br,
.hamster__limb--bl {
clip-path: polygon(0 0,100% 0,100% 20%,30% 100%,0% 100%);
top: 1.25em;
left: 2.8em;
width: 1.5em;
height: 2.5em;
transform-origin: 33% 10%;
}
.hamster__limb--br {
animation: hamsterBRLimb var(--dur) linear infinite;
background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
transform: rotate(-15deg) translateZ(-1px);
}
.hamster__limb--bl {
animation: hamsterBLLimb var(--dur) linear infinite;
background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
transform: rotate(60deg) translateZ(-1px);
}
.hamster__tail {
animation: hamsterTail var(--dur) linear infinite;
background: hsl(0,90%,85%);
border-radius: 0.25em 50% 50% 0.25em;
box-shadow: 0.25em 0 hsl(30,90%,55%) inset;
top: 1.5em;
left: 5.5em;
width: 0.5em;
height: 0.75em;
transform: rotate(30deg) translateZ(-1px);
transform-origin: 0.25em 0.125em;
}
.spoke {
background: radial-gradient(hsl(0,0%,70%) 25%,hsla(0,0%,60%,0) 26%) center/8px 8px;
z-index: 0;
}
.spoke--1 {
animation: spoke var(--dur) linear infinite;
}
.spoke--2 {
animation: spoke var(--dur) linear infinite;
transform: rotate(30deg);
}
.spoke--3 {
animation: spoke var(--dur) linear infinite;
transform: rotate(60deg);
}
@keyframes hamster {
0%,100% { transform: rotate(4deg) translate(-0.8em,1.85em) }
50% { transform: rotate(0) translate(-0.8em,1.85em) }
}
@keyframes hamsterHead {
0%,100% { transform: rotate(0) }
50% { transform: rotate(-8deg) }
}
@keyframes hamsterEar {
0%,100% { transform: rotate(0) }
50% { transform: rotate(-3deg) }
}
@keyframes hamsterEye {
0%,90%,100% { transform: scaleY(1) }
95% { transform: scaleY(0) }
}
@keyframes hamsterBody {
0%,100% { transform: rotate(0) }
50% { transform: rotate(2deg) }
}
@keyframes hamsterFRLimb {
0%,100% { transform: rotate(15deg) translateZ(-1px) }
50% { transform: rotate(-30deg) translateZ(-1px) }
}
@keyframes hamsterFLLimb {
0%,100% { transform: rotate(-60deg) translateZ(-1px) }
50% { transform: rotate(-25deg) translateZ(-1px) }
}
@keyframes hamsterBRLimb {
0%,100% { transform: rotate(-15deg) translateZ(-1px) }
50% { transform: rotate(30deg) translateZ(-1px) }
}
@keyframes hamsterBLLimb {
0%,100% { transform: rotate(60deg) translateZ(-1px) }
50% { transform: rotate(25deg) translateZ(-1px) }
}
@keyframes hamsterTail {
0%,100% { transform: rotate(30deg) translateZ(-1px) }
50% { transform: rotate(10deg) translateZ(-1px) }
}
@keyframes spoke {
0% { transform: rotate(0) }
100% { transform: rotate(1turn) }
}
'''
# Markup for the animated hamster-in-a-wheel loader, rendered with gr.HTML.
# Its class names match the selectors and @keyframes defined in `css`.
# The outer #loading-animation div is a flex container, centered and one
# viewport tall (100vh). The inline script is meant to hide it on window load.
# NOTE(review): Gradio inserts gr.HTML content without executing <script> tags,
# as far as I know. If so, the hide-on-load script never runs and this 100vh
# block stays visible above the app. Confirm in a browser; if needed, move the
# script to gr.Blocks(js=...)/head=... or drop the fixed height.
html = '''
<div id="loading-animation" style="display: flex; justify-content: center; align-items: center; height: 100vh;">
<div class="wheel-and-hamster">
<div class="wheel"></div>
<div class="hamster">
<div class="hamster__body">
<div class="hamster__head">
<div class="hamster__ear"></div>
<div class="hamster__eye"></div>
<div class="hamster__nose"></div>
</div>
<div class="hamster__limb hamster__limb--fr"></div>
<div class="hamster__limb hamster__limb--fl"></div>
<div class="hamster__limb hamster__limb--br"></div>
<div class="hamster__limb hamster__limb--bl"></div>
<div class="hamster__tail"></div>
</div>
</div>
<div class="spoke spoke--1"></div>
<div class="spoke spoke--2"></div>
<div class="spoke spoke--3"></div>
</div>
</div>
<script>
window.onload = function() {
document.getElementById("loading-animation").style.display = "none";
}
</script>
'''
# Build the Gradio UI and wire the Generate button to generate().
with gr.Blocks(css=css) as demo:
    gr.HTML(html)
    gr.Markdown("# Flash Attention with SDXL")
    gr.Markdown("Generate images with Flash Attention and SDXL")
    with gr.Row():
        with gr.Column(scale=55):
            # Component `.style(...)` was deprecated in Gradio 3.x and removed in
            # 4.0. Layout options are passed to the constructor instead.
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter negative prompt",
                container=False,
            )
        with gr.Column(scale=45):
            generate_btn = gr.Button("Generate")
    with gr.Row():
        image_output = gr.Gallery(label="Generated Images", columns=2, height="auto")
        seed_output = gr.Number(label="Seed Used")
    gr.Examples(examples=examples, inputs=[prompt])

    # Generation controls, rendered below the examples. Every slider gets an
    # explicit `step`. Without one, Gradio derives the step from the range,
    # which gives fractional values for integer parameters and widths/heights
    # that are not multiples of 8, which SDXL rejects.
    use_negative_prompt_box = gr.Checkbox(False, label="Use Negative Prompt")
    seed_slider = gr.Slider(1, MAX_SEED, value=1, step=1, label="Seed")
    width_slider = gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=8, label="Width")
    height_slider = gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=8, label="Height")
    guidance_slider = gr.Slider(1, 20, value=7.5, step=0.1, label="Guidance Scale")
    steps_slider = gr.Slider(1, 100, value=30, step=1, label="Number of Inference Steps")
    randomize_seed_box = gr.Checkbox(False, label="Randomize Seed")
    binning_box = gr.Checkbox(True, label="Use Resolution Binning")
    num_images_slider = gr.Slider(1, 10, value=1, step=1, label="Number of Images")

    # Order must match generate()'s positional parameters.
    inputs = [
        prompt,
        negative_prompt,
        use_negative_prompt_box,
        seed_slider,
        width_slider,
        height_slider,
        guidance_slider,
        steps_slider,
        randomize_seed_box,
        binning_box,
        num_images_slider,
    ]
    generate_btn.click(fn=generate, inputs=inputs, outputs=[image_output, seed_output])

demo.launch()