import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces

# Shared fp16-fix VAE avoids NaN/black-image issues when decoding in float16; reused by all three pipelines
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

### SDXL Turbo ###
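# Single-step distilled SDXL from Stability AI; its text encoders and tokenizers are reused by the other pipelines below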
pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
                                                       vae=vae,
                                                       torch_dtype=torch.float16,
                                                       variant="fp16"
                                                      )
pipe_turbo.to("cuda")

### SDXL Lightning ### 
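# The 1-step Lightning UNet predicts the denoised sample (x0), hence prediction_type="sample" on the scheduler below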
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors" 

unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
pipe_lightning = StableDiffusionXLPipeline.from_pretrained(base,
                                                           unet=unet,
                                                           vae=vae,
                                                           text_encoder=pipe_turbo.text_encoder,
                                                           text_encoder_2=pipe_turbo.text_encoder_2,
                                                           tokenizer=pipe_turbo.tokenizer,
                                                           tokenizer_2=pipe_turbo.tokenizer_2,
                                                           torch_dtype=torch.float16,
                                                           variant="fp16"
                                                          )
del unet
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(pipe_lightning.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
pipe_lightning.to("cuda")

### Hyper SDXL ### 
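# 1-step Hyper-SD UNet; paired with LCMScheduler and a fixed timestep of 800 at inference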
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))
pipe_hyper = StableDiffusionXLPipeline.from_pretrained(base,
                                                       unet=unet,
                                                       vae=vae,
                                                       text_encoder=pipe_turbo.text_encoder,
                                                       text_encoder_2=pipe_turbo.text_encoder_2,
                                                       tokenizer=pipe_turbo.tokenizer,
                                                       tokenizer_2=pipe_turbo.tokenizer_2,
                                                       torch_dtype=torch.float16,
                                                       variant="fp16"
                                                      )
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
pipe_hyper.to("cuda")
del unet

# @spaces.GPU allocates GPU time per call when the demo runs on Hugging Face ZeroGPU hardware
@spaces.GPU
def run_comparison_turbo(prompt, progress=gr.Progress(track_tqdm=True)):
    image_turbo = pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
    return image_turbo

@spaces.GPU
def run_comparison_lightning(prompt, progress=gr.Progress(track_tqdm=True)):
    image_lightning = pipe_lightning(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
    return image_lightning

@spaces.GPU
def run_comparison_hyper(prompt, progress=gr.Progress(track_tqdm=True)):
    # Hyper-SD's 1-step UNet expects a fixed timestep of 800
    image_hyper = pipe_hyper(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[800]).images[0]
    return image_hyper

examples = ["A dignified beaver wearing glasses, a vest, and a colorful necktie.",
"The spirit of a tamagotchi wandering in the city of Barcelona",
"an ornate, high-backed mahogany chair with a red cushion",
"a sketch of a camel next to a stream",
"a delicate porcelain teacup sits on a saucer, its surface adorned with intricate blue patterns",
"a baby swan graffiti",
"A bald eagle made of chocolate powder, mango, and whipped cream"
]

# Gradio UI: one prompt box and three output columns; the pipelines run sequentially via chained events
with gr.Blocks() as demo:
    gr.Markdown("## One-step SDXL comparison 🦶")
    gr.Markdown("Compare SDXL variants and distillations that can generate an image in a single diffusion step")
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        with gr.Column():
            image_turbo = gr.Image(label="SDXL Turbo")
            gr.Markdown("## [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo)")
        with gr.Column():
            image_lightning = gr.Image(label="SDXL Lightning")
            gr.Markdown("## [SDXL Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)")
        with gr.Column():
            image_hyper = gr.Image(label="Hyper SDXL")
            gr.Markdown("## [Hyper SDXL](https://huggingface.co/ByteDance/Hyper-SD)")
    image_outputs = [image_turbo, image_lightning, image_hyper]
    gr.on(
        triggers=[prompt.submit, run.click],
        fn=run_comparison_turbo,
        inputs=prompt,
        outputs=image_turbo
    ).then(
        fn=run_comparison_lightning,
        inputs=prompt,
        outputs=image_lightning
    ).then(
        fn=run_comparison_hyper,
        inputs=prompt,
        outputs=image_hyper
    )
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=image_outputs,
        cache_examples=False
    )
demo.launch()