import gradio as gr
import pillow_heif
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL

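# Register HEIF/AVIF support so Pillow can open those image formats.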
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

TITLE = """ |
|
# SDXL with Refiners |
|
""" |
|
|
|
|
|
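# Pick devices and dtype: build on the CPU, run on the GPU when available,
# and use bfloat16 if the hardware supports it.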
DEVICE_CPU = torch.device("cpu")
DEVICE_GPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32

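# Instantiate SDXL on the CPU and load the Refiners-format weights
# (pinned to specific revisions) from the Hugging Face Hub.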
model = StableDiffusion_XL(device=DEVICE_CPU, dtype=DTYPE)
model.unet.load_from_safetensors(
    tensors_path=hf_hub_download(
        repo_id="refiners/sdxl.unet",
        filename="model.safetensors",
        revision="52a645e5b604a94a9d2b0c0e56b6ae059e80987b",
    )
)
model.lda.load_from_safetensors(
    tensors_path=hf_hub_download(
        repo_id="refiners/sdxl.autoencoder",
        filename="model.safetensors",
        revision="4c2a697138e728c6d2d1e0cf3a1327181f704a2c",
    )
)
model.clip_text_encoder.load_from_safetensors(
    tensors_path=hf_hub_download(
        repo_id="refiners/sdxl.text_encoder",
        filename="model.safetensors",
        revision="5c8e667196725a0e404cabf51fca8d3cda2436fa",
    )
)

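# Move every component (including the solver's internal tensors) to the GPU,
# and update `device`/`dtype` so later tensor allocations match.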
model.to(device=DEVICE_GPU, dtype=DTYPE)
model.unet.to(device=DEVICE_GPU, dtype=DTYPE)
model.lda.to(device=DEVICE_GPU, dtype=DTYPE)
model.clip_text_encoder.to(device=DEVICE_GPU, dtype=DTYPE)
model.solver.to(device=DEVICE_GPU, dtype=DTYPE)
model.device = DEVICE_GPU
model.dtype = DTYPE

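# `spaces.GPU` allocates a GPU for the duration of each call (ZeroGPU);
# `no_grad` disables gradient tracking since this is inference only.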
@spaces.GPU
@no_grad()
def process(
    prompt: str,
    negative_prompt: str,
    condition_scale: float,
    num_inference_steps: int,
    seed: int,
) -> Image.Image:
    assert condition_scale >= 0
    assert num_inference_steps > 0
    assert seed >= 0
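    # Seed the global torch RNG so a given seed reproduces the same image.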
    manual_seed(seed)

    # Honor the requested number of denoising steps (otherwise the
    # `num_inference_steps` argument would go unused).
    model.set_inference_steps(num_inference_steps)
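    # Encode the prompt and negative prompt with the CLIP text encoders;
    # SDXL also needs the pooled text embedding.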
    clip_text_embedding, pooled_text_embedding = model.compute_clip_text_embedding(
        text=prompt,
        negative_text=negative_prompt,
    )
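    # SDXL conditions on "time ids" (original size, crop coordinates, target
    # size); use the defaults, which match 1024x1024 output.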
    time_ids = model.default_time_ids
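    # Start from pure noise in latent space (1024x1024 pixels -> 128x128 latents).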
    x = model.init_latents(size=(1024, 1024))
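    # Iteratively denoise the latents; each call applies one solver step with
    # classifier-free guidance at the given condition scale.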
    for step in model.steps:
        x = model(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            condition_scale=condition_scale,
            time_ids=time_ids,
        )
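    # Decode the final latents back to a PIL image with the autoencoder.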
    image = model.lda.latents_to_image(x)

    return image

with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Column():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button(
                value="Run",
                scale=0,
            )
        output_image = gr.Image(
            label="Output Image",
            image_mode="RGB",
            type="pil",
        )

    with gr.Accordion("Advanced Settings", open=True):
        negative_prompt = gr.Textbox(
            label="Negative Prompt",
            placeholder="Enter your (optional) negative prompt",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=100_000,
            value=2,
            step=1,
        )
        condition_scale = gr.Slider(
            label="Condition scale",
            minimum=0,
            maximum=20,
            value=5,
            step=0.05,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            value=30,
            step=1,
        )
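    # Run generation when the button is clicked.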
    run_button.click(
        fn=process,
        inputs=[
            prompt,
            negative_prompt,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
    )
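    # Example prompts; with lazy caching, each one is generated and cached on
    # first selection rather than at startup.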
    gr.Examples(
        examples=[
            [
                "a cute cat, detailed high-quality professional image",
                "lowres, bad anatomy, bad hands, cropped, worst quality",
                5.0,
                30,
                2,
            ],
            [
                "a cute dog, detailed high-quality professional image",
                "lowres, bad anatomy, bad hands, cropped, worst quality",
                5.0,
                30,
                2,
            ],
        ],
        inputs=[
            prompt,
            negative_prompt,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
        fn=process,
        cache_examples=True,
        cache_mode="lazy",
        run_on_click=False,
    )

demo.launch()