dbaranchuk's picture
Update app.py
7418952 verified
raw
history blame
5.43 kB
import spaces
import gradio as gr
import numpy as np
import random
import generation_sdxl
import functools
from diffusers import DiffusionPipeline, UNet2DConditionModel, StableDiffusionXLPipeline, DDIMScheduler
import torch
# Pick the compute device; on Hugging Face ZeroGPU the `spaces` package makes
# CUDA appear available at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only dependable on CUDA; the CPU fallback path needs fp32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

model_id = 'stabilityai/stable-diffusion-xl-base-1.0'

# Load the CFG-distilled UNet first and hand it to the pipeline directly, so the
# base SDXL UNet is never downloaded and moved to the device only to be replaced.
# It is loaded without a torch_dtype override (as before) so the LoRA below is
# fused at that precision before the final cast to `torch_dtype`.
unet = UNet2DConditionModel.from_pretrained("dbaranchuk/sdxl-cfg-distill-unet")
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch_dtype,
    scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
    variant="fp16",
).to(device)

# iCD LoRA; the file name suggests it targets the 4-step schedule
# [249, 499, 699, 999] used by `infer` — confirm against the model card.
pipe.load_lora_weights("dbaranchuk/icd-lora-sdxl",
                       weight_name='reverse-249-499-699-999.safetensors')
pipe.fuse_lora()
# Bring every component (including the fused distilled UNet) to the working dtype.
pipe.to(dtype=torch_dtype, device=device)

# Upper bound for user/random seeds (fits a signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
@spaces.GPU()
def infer(prompt, seed, randomize_seed, tau,
          guidance_scale):
    """Generate one image for ``prompt`` with the 4-step iCD-XL sampler.

    Args:
        prompt: Text prompt from the UI textbox.
        seed: RNG seed; ignored when ``randomize_seed`` is true. Cast to
            ``int`` because ``torch.Generator.manual_seed`` rejects floats.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        tau: Dynamic-guidance threshold. Values below 1.0 enable dynamic
            guidance; the value is forwarded as both ``tau1`` and ``tau2``.
        guidance_scale: Guidance scale forwarded to the sampler.

    Returns:
        Element ``[0]`` of ``generation_sdxl.sample_deterministic``'s output
        (the result for the single prompt submitted).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(int(seed))

    # Embedding function bound to both SDXL text encoders / tokenizers.
    compute_embeddings_fn = functools.partial(
        generation_sdxl.compute_embeddings,
        proportion_empty_prompts=0,
        text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
        tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
    )

    outputs = generation_sdxl.sample_deterministic(
        pipe,
        [prompt],  # the sampler takes a batch of prompts
        num_inference_steps=4,
        generator=generator,
        guidance_scale=guidance_scale,
        is_sdxl=True,
        timesteps=[249, 499, 699, 999],
        use_dynamic_guidance=tau < 1.0,
        tau1=tau,
        tau2=tau,
        compute_embeddings_fn=compute_embeddings_fn,
    )
    image = outputs[0]
    return image
# One-click example prompts shown under the prompt box.
examples = [
    "An astronaut riding a green horse",
    "Long-exposure night photography of a starry sky over a mountain range, with light trails.",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A portrait of a girl with blonde, tousled hair, blue eyes",
]

# Stylesheet for the main column (centered, width-capped).
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""

# Hardware label interpolated into the header text.
power_device = "GPU" if torch.cuda.is_available() else "CPU"
# Header markdown; `power_device` reports the hardware the Space runs on.
_HEADER_MD = f"""
# ⚡ Invertible Consistency Distillation ⚡
# ⚡ Image Generation with 4-step iCD-XL ⚡
This is a demo of [Invertible Consistency Distillation](https://yandex-research.github.io/invertible-cd/),
a diffusion distillation method proposed in [Invertible Consistency Distillation for Text-Guided Image Editing in Around 7 Steps](https://arxiv.org/abs/2406.14539)
by [Yandex Research](https://github.com/yandex-research).
Currently running on {power_device}.
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(_HEADER_MD)
        gr.Markdown(
            "Feel free to check out our [image editing demo](https://huggingface.co/spaces/dbaranchuk/iCD-image-editing) as well."
        )
        gr.Markdown(
            "If you enjoy the space, feel free to give a ⭐ to the <a href='https://github.com/yandex-research/invertible-cd' target='_blank'>Github Repo</a>. [![GitHub Stars](https://img.shields.io/github/stars/yandex-research/invertible-cd?style=social)](https://github.com/yandex-research/invertible-cd)"
        )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=19.0,
                    step=1.0,
                    value=7.0,
                )
                dynamic_guidance_tau = gr.Slider(
                    label="Dynamic guidance tau",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=1.0,
                )
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            cache_examples=False,
        )

    # Order must match infer(prompt, seed, randomize_seed, tau, guidance_scale).
    infer_inputs = [prompt, seed, randomize_seed, dynamic_guidance_tau, guidance_scale]
    run_button.click(fn=infer, inputs=infer_inputs, outputs=[result])
    # The prompt box is single-line, so Enter should generate as well.
    prompt.submit(fn=infer, inputs=infer_inputs, outputs=[result])

demo.queue().launch(share=False)