File size: 4,404 Bytes
c2c42ca
d827a95
81435cb
61bc6a3
 
6da6b11
9d41bd5
a1f66f7
 
c2c42ca
aa5a24b
 
 
 
 
91dd651
c2c42ca
61bc6a3
 
 
 
c2c42ca
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
c2c42ca
61bc6a3
 
 
143f063
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
aa5a24b
91dd651
6da6b11
825bfd6
61bc6a3
 
 
 
143f063
61bc6a3
dc81866
c2c42ca
 
 
dc81866
 
 
 
 
 
 
 
 
34cb1b5
 
5e64d98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces

# VAE checkpoint loaded once in float16; the same instance is handed to all
# three pipelines below so it is only held in memory once.
VAE_REPO = "madebyollin/sdxl-vae-fp16-fix"
vae = AutoencoderKL.from_pretrained(VAE_REPO, torch_dtype=torch.float16)

### SDXL Turbo ###
TURBO_REPO = "stabilityai/sdxl-turbo"
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    TURBO_REPO,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
# The later pipelines borrow this pipeline's text encoders and tokenizers.
pipe_turbo.to("cuda")

### SDXL Lightning ###
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"

# Instantiate the UNet architecture from the base model's config, then load
# the 1-step Lightning checkpoint into it. The config is fetched with
# load_config() and passed as a dict, because handing a repo id straight to
# from_config() is deprecated in diffusers.
unet = UNet2DConditionModel.from_config(
    UNet2DConditionModel.load_config(base, subfolder="unet")
).to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))

pipe_lightning = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    vae=vae,
    # Reuse the text encoders and tokenizers already loaded for SDXL Turbo
    # instead of loading a second copy from the base repo.
    text_encoder=pipe_turbo.text_encoder,
    text_encoder_2=pipe_turbo.text_encoder_2,
    tokenizer=pipe_turbo.tokenizer,
    tokenizer_2=pipe_turbo.tokenizer_2,
    torch_dtype=torch.float16,
    variant="fp16",
)
# Drop the module-level name; the pipeline keeps its own reference to the UNet.
del unet

# Swap in an Euler scheduler with trailing timestep spacing and "sample"
# prediction (the settings the SDXL-Lightning model card gives for the
# 1-step x0 checkpoint).
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)
pipe_lightning.to("cuda")

### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

# Instantiate the UNet architecture from the base model's config (`base` is
# the SDXL base repo id defined earlier in this file), then load the 1-step
# Hyper-SD checkpoint into it. The config is fetched with load_config() and
# passed as a dict, because handing a repo id straight to from_config() is
# deprecated in diffusers.
unet = UNet2DConditionModel.from_config(
    UNet2DConditionModel.load_config(base, subfolder="unet")
).to(torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo_name, ckpt_name)))

pipe_hyper = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    vae=vae,
    # Reuse the text encoders and tokenizers already loaded for SDXL Turbo
    # instead of loading a second copy from the base repo.
    text_encoder=pipe_turbo.text_encoder,
    text_encoder_2=pipe_turbo.text_encoder_2,
    tokenizer=pipe_turbo.tokenizer,
    tokenizer_2=pipe_turbo.tokenizer_2,
    torch_dtype=torch.float16,
    variant="fp16",
)
# Replace the pipeline's default scheduler with an LCM scheduler built from
# the same scheduler config.
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
pipe_hyper.to("cuda")
# Drop the module-level name; the pipeline keeps its own reference to the UNet.
del unet

@spaces.GPU
def run_comparison(prompt):
    """Render ``prompt`` once with each of the three one-step SDXL pipelines.

    Every pipeline is called with ``num_inference_steps=1`` and
    ``guidance_scale=0``; the Hyper SDXL call additionally passes
    ``timesteps=[800]``.

    Returns:
        A 3-tuple holding the first generated image from, in order,
        SDXL Turbo, SDXL Lightning and Hyper SDXL.
    """
    # Arguments common to all three pipeline calls.
    one_step_args = {
        "prompt": prompt,
        "num_inference_steps": 1,
        "guidance_scale": 0,
    }

    turbo_image = pipe_turbo(**one_step_args).images[0]
    lightning_image = pipe_lightning(**one_step_args).images[0]
    hyper_image = pipe_hyper(timesteps=[800], **one_step_args).images[0]

    return turbo_image, lightning_image, hyper_image

# One (image label, markdown caption) pair per output column, listed in the
# left-to-right order in which the columns are created.
MODEL_COLUMNS = (
    ("SDXL Turbo", "## [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo)"),
    ("SDXL Lightning", "## [SDXL Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)"),
    ("Hyper SDXL", "## [Hyper SDXL](https://huggingface.co/ByteDance/Hyper-SD)"),
)

with gr.Blocks() as demo:
    gr.Markdown("## One step SDXL comparison 🦶")
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")

    with gr.Row():
        result_images = []
        for column_label, column_caption in MODEL_COLUMNS:
            with gr.Column():
                result_images.append(gr.Image(label=column_label))
                gr.Markdown(column_caption)
    image_turbo, image_lightning, image_hyper = result_images

    # The output components are listed in the same order as the tuple
    # returned by run_comparison: Turbo, Lightning, Hyper.
    run.click(
        fn=run_comparison,
        inputs=prompt,
        outputs=[image_turbo, image_lightning, image_hyper],
    )

demo.launch()