File size: 2,286 Bytes
fafd12c
 
f2e9ef7
fafd12c
 
 
 
 
 
 
 
 
f2e9ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fafd12c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2e9ef7
 
 
 
 
 
 
 
 
 
 
d6d99c2
 
f2e9ef7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import gradio as gr
from PIL import Image
import spaces
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Every pipeline below is moved onto this single device.
device = "cuda"
# How many images the prior is asked to produce for each prompt (read by gen).
num_images_per_prompt = 1

# Prior pipeline: text prompt -> image embeddings. Loaded in bfloat16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to(device)

# Decoder pipeline: image embeddings -> images. Loaded in float16, which is
# why gen casts the prior's embeddings before handing them over.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to(device)

# Custom stylesheet passed to gr.Blocks(css=...) when the UI is built.
# - The `footer` rule hides <footer> elements on the page.
# - `#generate_button`, `#save_button` and `#settings_header` are CSS id
#   selectors: each only takes effect on a component created with a matching
#   `elem_id=` argument.
#   NOTE(review): no component in this file sets elem_id="save_button" or
#   elem_id="settings_header", so those two rules appear to be unused.
css = """
footer {
    visibility: hidden
}

#generate_button {
    color: white;
    border-color: #007bff;
    background: #2563eb;

}

#save_button {
    color: white;
    border-color: #028b40;
    background: #01b97c;
    width: 200px;
}

#settings_header {
    background: rgb(245, 105, 105);

}
"""

@spaces.GPU
def gen(
    prompt,
    negative,
    width,
    height,
    *,
    prior_guidance_scale=4.0,
    prior_steps=20,
    decoder_guidance_scale=0.0,
    decoder_steps=10,
):
    """Generate images for *prompt* with the two-stage Stable Cascade pipeline.

    The prior turns the text prompt into image embeddings; the decoder then
    turns those embeddings into PIL images.

    Args:
        prompt: Text prompt.
        negative: Negative prompt (the UI may send an empty string).
        width: Requested output width in pixels. Coerced to ``int`` because
            Gradio's Slider component is typed to deliver floats.
        height: Requested output height in pixels; coerced like *width*.
        prior_guidance_scale: Guidance scale for the prior (default matches
            the previously hard-coded value).
        prior_steps: Number of inference steps for the prior.
        decoder_guidance_scale: Guidance scale for the decoder.
        decoder_steps: Number of inference steps for the decoder.

    Returns:
        The decoder's ``.images`` list of PIL images, used directly as the
        ``gr.Gallery`` output.
    """
    width, height = int(width), int(height)

    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=prior_steps,
    )

    # The prior is loaded in bfloat16 and the decoder in float16; cast the
    # embeddings to whatever dtype the decoder actually uses instead of
    # hard-coding float16, so the two stay in sync if the loading changes.
    image_embeddings = prior_output.image_embeddings.to(decoder.dtype)

    return decoder(
        image_embeddings=image_embeddings,
        prompt=prompt,
        negative_prompt=negative,
        guidance_scale=decoder_guidance_scale,
        output_type="pil",
        num_inference_steps=decoder_steps,
    ).images

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Stable Cascade ```DEMO```")
    with gr.Row():
        prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=3, lines=1, interactive=True, scale=20)
        # elem_id links this button to the `#generate_button` rule in `css`;
        # without it that styling never applies.
        button = gr.Button(value="Generate", scale=1, elem_id="generate_button")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative = gr.Textbox(show_label=False, placeholder="Enter a negative", max_lines=2, lines=1, interactive=True)
        with gr.Row():
            width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
            height = gr.Slider(label="Height", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
    with gr.Row():
        gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)

    # Same inputs for every trigger, in gen's positional parameter order.
    generate_inputs = [prompt, negative, width, height]
    button.click(gen, inputs=generate_inputs, outputs=gallery)
    # Pressing Enter in the prompt box generates as well, not just the button.
    prompt.submit(gen, inputs=generate_inputs, outputs=gallery)

demo.launch(show_api=False)