File size: 6,852 Bytes
248bc06
 
49901fa
b14c223
248bc06
b14c223
ce4b9eb
248bc06
 
 
b14c223
 
 
 
 
 
 
 
 
 
 
 
248bc06
 
 
 
 
49901fa
248bc06
49901fa
 
 
 
 
 
 
 
 
248bc06
49901fa
 
 
 
 
 
 
 
248bc06
ce4b9eb
b14c223
 
 
 
f1a6530
 
b14c223
 
 
248bc06
b14c223
248bc06
b14c223
248bc06
b14c223
248bc06
b14c223
248bc06
b14c223
 
 
 
 
 
 
 
 
 
 
 
248bc06
 
f1a6530
248bc06
 
60d647a
 
 
 
 
 
248bc06
 
 
b14c223
 
 
 
 
 
248bc06
 
 
 
 
 
 
b14c223
 
f1a6530
8fbd147
 
 
 
b14c223
 
248bc06
 
 
b14c223
248bc06
 
 
f1a6530
248bc06
b14c223
 
 
 
 
f1a6530
b14c223
 
f1a6530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b14c223
 
 
 
 
 
 
248bc06
b14c223
 
 
 
 
 
611e66d
 
 
 
 
 
 
 
 
 
 
b14c223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1a6530
b14c223
248bc06
f1a6530
248bc06
b14c223
 
248bc06
611e66d
f1a6530
b14c223
 
 
 
 
f1a6530
b14c223
248bc06
 
b14c223
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
import torch
import devicetorch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LCMScheduler
from diffusers.schedulers import TCDScheduler

#import spaces
from PIL import Image

# Maps the label shown in the UI dropdown to
# [weight-file name template, default number of inference steps, guidance scale].
# The "{}" placeholder in the template is filled with the mode ("sdxl"/"sd15")
# via str.format() when the LoRA is loaded.
checkpoints = {
    # Small-CFG variants: guidance scale 0.0.
    **{
        f"{n}-Step": [f"pcm_{{}}_smallcfg_{n}step_converted.safetensors", n, 0.0]
        for n in (2, 4, 8, 16)
    },
    # Normal-CFG variants: guidance scale 7.5.
    **{
        f"Normal CFG {n}-Step": [
            f"pcm_{{}}_normalcfg_{n}step_converted.safetensors",
            n,
            7.5,
        ]
        for n in (4, 8, 16)
    },
    "LCM-Like LoRA": ["pcm_{}_lcmlike_lora_converted.safetensors", 4, 0.0],
}


# Target device as reported by devicetorch for the installed torch build
# (presumably cuda / mps / cpu -- confirm against the devicetorch docs).
device = devicetorch.get(torch)

# Both base pipelines are built once at import time from the "fp16" weight
# variant in half precision, then moved to `device`.
pipe_sdxl = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe_sdxl = pipe_sdxl.to(device)

pipe_sd15 = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe_sd15 = pipe_sd15.to(device)

# Key (ckpt + mode) of the LoRA most recently loaded by generate_image();
# None until the first generation request.
loaded = None

#@spaces.GPU(enable_queue=True)
def generate_image(
    prompt,
    ckpt,
    num_inference_steps,
    width,
    height,
    progress=gr.Progress(track_tqdm=True),
    mode="sdxl",
):
    """Generate one image with the PCM LoRA selected by ``ckpt``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        ckpt: Key into ``checkpoints``; selects the LoRA weight file and the
            guidance scale.
        num_inference_steps: Number of inference steps passed to the pipeline.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        progress: Gradio progress tracker. It is not used directly in the
            body; it is declared with ``track_tqdm=True`` so Gradio can
            surface tqdm progress in the UI.
        mode: ``"sdxl"`` runs on ``pipe_sdxl``; any other value runs on
            ``pipe_sd15``. The value also fills the ``{}`` placeholder of the
            weight-file name and is used as the sub-folder in the weights repo.

    Returns:
        The first image of the pipeline result.

    Raises:
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    global loaded
    weight_template, _, guidance_scale = checkpoints[ckpt]
    weight_name = weight_template.format(mode)
    pipe = pipe_sdxl if mode == "sdxl" else pipe_sd15

    lora_key = ckpt + mode
    if loaded != lora_key:
        # Invalidate the marker first: if anything below raises, the next call
        # must redo the whole setup instead of trusting a half-configured pipe.
        loaded = None

        # Drop whatever PCM LoRA this pipeline still carries from an earlier
        # request; otherwise every checkpoint switch would add another adapter
        # on top of the previous ones.
        pipe.unload_lora_weights()
        pipe.load_lora_weights(
            "wangfuyun/PCM_Weights", weight_name=weight_name, subfolder=mode
        )

        if ckpt == "LCM-Like LoRA":
            pipe.scheduler = LCMScheduler()
        else:
            pipe.scheduler = TCDScheduler(
                num_train_timesteps=1000,
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                timestep_spacing="trailing",
            )

        # Only record the key once LoRA and scheduler are both in place.
        loaded = lora_key

    results = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
    )

    # NOTE: the optional NSFW post-filter (SAFETY_CHECKER / check_nsfw_images)
    # is not defined in the visible part of this file, so no extra filtering
    # is applied here and the first image is returned as produced.
    return results.images[0]


def update_steps(ckpt):
    """Return a slider update matching the checkpoint picked in the dropdown.

    The slider value becomes the step count stored for ``ckpt`` in
    ``checkpoints``; the slider is editable only for "LCM-Like LoRA" and
    locked for every other checkpoint.
    """
    default_steps = checkpoints[ckpt][1]
    user_adjustable = ckpt == "LCM-Like LoRA"
    return gr.update(interactive=user_adjustable, value=default_steps)


# Custom CSS: caps the width of the Gradio container.
css = """
.gradio-container {
  max-width: 60rem !important;
}
"""
# Build the UI. Components are created in the order they should appear, so the
# statement order below is presumably significant for layout -- do not reorder
# without checking the rendered page.
with gr.Blocks(css=css) as demo:
    # Header text with links to the paper, code and project page.
    gr.Markdown(
        """
# Phased Consistency Model

Phased Consistency Model (PCM) is an image generation technique that addresses the limitations of the Latent Consistency Model (LCM) in high-resolution and text-conditioned image generation.
PCM outperforms LCM across various generation settings and achieves state-of-the-art results in both image and video generation.

[[paper](https://huggingface.co/papers/2405.18407)] [[arXiv](https://arxiv.org/abs/2405.18407)]  [[code](https://github.com/G-U-N/Phased-Consistency-Model)] [[project page](https://g-u-n.github.io/projects/pcm)]
"""
    )
    # Input controls: prompt, checkpoint selector, steps/width/height sliders
    # and the two run buttons, all in a single row.
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", scale=8)
            # Dropdown choices are the keys of the `checkpoints` table.
            ckpt = gr.Dropdown(
                label="Select inference steps",
                choices=list(checkpoints.keys()),
                value="2-Step",
            )
            # Starts locked at 2, matching the default "2-Step" checkpoint;
            # update_steps() changes value/interactivity when the dropdown changes.
            steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=20,
                step=1,
                value=2,
                interactive=False,
            )
            # Width and height can take 512, 768 or 1024 (step of 256).
            width = gr.Slider(
                label="Width",
                minimum=512,
                maximum=1024,
                step=256,
                value=512,
                interactive=True
            )
            height = gr.Slider(
                label="Height",
                minimum=512,
                maximum=1024,
                step=256,
                value=512,
                interactive=True
            )
            # Keep the steps slider in sync with the selected checkpoint.
            ckpt.change(
                fn=update_steps,
                inputs=[ckpt],
                outputs=[steps],
                queue=False,
                show_progress=False,
            )

            submit_sdxl = gr.Button("Run on SDXL", scale=1)
            submit_sd15 = gr.Button("Run on SD15", scale=1)

    # Output image.
    img = gr.Image(label="PCM Image")
    # Clickable example rows of the form [prompt, checkpoint label, steps].
    # NOTE(review): each row supplies three values while five input components
    # are listed below, so width and height receive no example value -- confirm
    # this is intended.
    gr.Examples(
        examples=[
            [" astronaut walking on the moon", "4-Step", 4],
            [
                "Photo of a dramatic cliffside lighthouse in a storm, waves crashing, symbol of guidance and resilience.",
                "8-Step",
                8,
            ],
            [
                "Vincent vangogh style, painting, a boy, clouds in the sky",
                "Normal CFG 4-Step",
                4,
            ],
            [
                "Echoes of a forgotten song drift across the moonlit sea, where a ghost ship sails, its spectral crew bound to an eternal quest for redemption.",
                "4-Step",
                4,
            ],
            [
                "Roger rabbit as a real person, photorealistic, cinematic.",
                "16-Step",
                16,
            ],
            # NOTE(review): "tanding" looks like a truncated "Standing" -- confirm
            # before changing the prompt text.
            [
                "tanding tall amidst the ruins, a stone golem awakens, vines and flowers sprouting from the crevices in its body.",
                "LCM-Like LoRA",
                4,
            ],
        ],
        inputs=[prompt, ckpt, steps, width, height],
        outputs=[img],
        fn=generate_image,
        #cache_examples="lazy",
    )

    # Generate with the default mode ("sdxl") when the checkpoint changes, the
    # prompt is submitted, or the SDXL button is clicked.
    # NOTE(review): ckpt.change also drives update_steps above; the `steps`
    # value read for this call may still be the previous checkpoint's value --
    # verify.
    gr.on(
        fn=generate_image,
        triggers=[ckpt.change, prompt.submit, submit_sdxl.click],
        inputs=[prompt, ckpt, steps, width, height],
        outputs=[img],
    )
    # Same inputs, but run with mode="sd15"; only the SD15 button triggers it.
    gr.on(
        fn=lambda *args: generate_image(*args, mode="sd15"),
        triggers=[submit_sd15.click],
        inputs=[prompt, ckpt, steps, width, height],
        outputs=[img],
    )


# Enable the request queue and start the app, with api_open and show_api both
# switched off.
demo.queue(api_open=False).launch(show_api=False)