import gradio as gr # pyright: ignore[reportMissingTypeStubs]
import pillow_heif # pyright: ignore[reportMissingTypeStubs]
import spaces # pyright: ignore[reportMissingTypeStubs]
import torch
from huggingface_hub import ( # pyright: ignore[reportMissingTypeStubs]
hf_hub_download, # pyright: ignore[reportUnknownVariableType]
)
from PIL import Image
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
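# enable HEIF/AVIF support in Pillow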
pillow_heif.register_heif_opener() # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener() # pyright: ignore[reportUnknownMemberType]
TITLE = """
# SDXL with Refiners
"""
# initialize the model on the CPU
DEVICE_CPU = torch.device("cpu")
DEVICE_GPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# prefer bfloat16 when supported; check CUDA availability first to avoid an error on CPU-only hosts
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
model = StableDiffusion_XL(device=DEVICE_CPU, dtype=DTYPE)
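# load the converted weights from the Hub, pinned to specific revisions for reproducibility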
model.unet.load_from_safetensors(
tensors_path=hf_hub_download(
repo_id="refiners/sdxl.unet",
filename="model.safetensors",
revision="52a645e5b604a94a9d2b0c0e56b6ae059e80987b",
)
)
model.lda.load_from_safetensors(
tensors_path=hf_hub_download(
repo_id="refiners/sdxl.autoencoder",
filename="model.safetensors",
revision="4c2a697138e728c6d2d1e0cf3a1327181f704a2c",
)
)
model.clip_text_encoder.load_from_safetensors(
tensors_path=hf_hub_download(
repo_id="refiners/sdxl.text_encoder",
filename="model.safetensors",
revision="5c8e667196725a0e404cabf51fca8d3cda2436fa",
)
)
# "move" the model to the gpu, this is handled/intercepted by Zero GPU
model.to(device=DEVICE_GPU, dtype=DTYPE)
model.unet.to(device=DEVICE_GPU, dtype=DTYPE)
model.lda.to(device=DEVICE_GPU, dtype=DTYPE)
model.clip_text_encoder.to(device=DEVICE_GPU, dtype=DTYPE)
model.solver.to(device=DEVICE_GPU, dtype=DTYPE)
model.device = DEVICE_GPU
model.dtype = DTYPE
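# ZeroGPU attaches a GPU for the duration of each call; gradients are never needed at inference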
@spaces.GPU
@no_grad()
def process(
prompt: str,
negative_prompt: str,
condition_scale: float,
num_inference_steps: int,
seed: int,
) -> Image.Image:
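    """Generate a 1024x1024 image from `prompt`, guided by `negative_prompt`."""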
assert condition_scale >= 0
assert num_inference_steps > 0
assert seed >= 0
    # set the seed
    manual_seed(seed)
    # configure the solver for the requested number of denoising steps
    model.set_inference_steps(num_inference_steps)
    # compute the CLIP text embeddings for the prompt and negative prompt
clip_text_embedding, pooled_text_embedding = model.compute_clip_text_embedding(
text=prompt,
negative_text=negative_prompt,
)
    # get the default SDXL time ids (size and crop conditioning)
time_ids = model.default_time_ids
    # initialize the latents with random noise
x = model.init_latents(size=(1024, 1024))
# denoise latents
for step in model.steps:
x = model(
x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
condition_scale=condition_scale,
time_ids=time_ids,
)
# decode denoised latents
image = model.lda.latents_to_image(x)
return image
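# build the Gradio UI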
with gr.Blocks() as demo:
gr.Markdown(TITLE)
with gr.Column():
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button(
value="Run",
scale=0,
)
output_image = gr.Image(
label="Output Image",
image_mode="RGB",
type="pil",
)
with gr.Accordion("Advanced Settings", open=True):
negative_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Enter your (optional) negative prompt",
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=100_000,
value=2,
step=1,
)
condition_scale = gr.Slider(
label="Condition scale",
minimum=0,
maximum=20,
value=5,
step=0.05,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
value=30,
step=1,
)
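    # generate an image when the button is clicked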
run_button.click(
fn=process,
inputs=[
prompt,
negative_prompt,
condition_scale,
num_inference_steps,
seed,
],
outputs=output_image,
)
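    # example prompts, cached lazily on first use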
gr.Examples( # pyright: ignore[reportUnknownMemberType]
examples=[
[
"a cute cat, detailed high-quality professional image",
"lowres, bad anatomy, bad hands, cropped, worst quality",
5.0,
30,
2,
],
[
"a cute dog, detailed high-quality professional image",
"lowres, bad anatomy, bad hands, cropped, worst quality",
5.0,
30,
2,
],
],
inputs=[
prompt,
negative_prompt,
condition_scale,
num_inference_steps,
seed,
],
outputs=output_image,
fn=process,
cache_examples=True,
cache_mode="lazy",
run_on_click=False,
)
demo.launch()