File size: 3,608 Bytes
8b42472
a324e60
 
8b42472
 
a324e60
8b42472
a324e60
8b42472
 
 
a324e60
8b42472
 
 
 
 
 
a324e60
8b42472
 
 
 
 
 
 
 
a324e60
8b42472
 
 
 
 
 
 
 
 
eb8e92b
8b42472
 
 
 
 
 
 
 
a324e60
8b42472
a324e60
8b42472
 
 
 
 
a324e60
8b42472
 
 
 
 
 
a324e60
8b42472
a324e60
8b42472
 
 
 
a324e60
8b42472
a324e60
8b42472
 
 
a324e60
8b42472
 
 
 
83c7d4f
f447392
8b42472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a324e60
 
8b42472
 
a324e60
8b42472
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLImg2ImgPipeline,
)

# All pipelines are placed on the GPU.
device = "cuda"

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

# fp16-safe SDXL VAE (the stock SDXL VAE can produce NaNs in half precision).
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# Stage 1: the SDXL base text-to-image pipeline.
base_pipeline = DiffusionPipeline.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)

# Stage 2: the refiner, reusing the base model's second text encoder
# (saves memory) and the fp16-fixed VAE above.
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    refiner_id,
    text_encoder_2=base_pipeline.text_encoder_2,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)



# Map UI sampler names to factories that build a configured scheduler from an
# existing scheduler config (e.g. `pipeline.scheduler.config`).
#
# Fixes vs. the original:
#  - DPMSolverMultistepScheduler / EulerDiscreteScheduler were never imported,
#    so invoking either lambda raised NameError (now imported at file top).
#  - The flag was spelled `use_karras=True`, which diffusers' `from_config`
#    silently ignores; the documented parameter is `use_karras_sigmas`.
SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}



@spaces.GPU(duration=59)
def generate(
    prompt, 
    negative_prompt, 
    num_inference_steps, 
    denoising_switch, 
    width, height, 
    guidance_scale
):
    """Run the SDXL base+refiner two-stage pipeline and return a PIL image.

    Implements the "ensemble of expert denoisers" handoff: the base model
    denoises from the start up to `denoising_switch` (a 0–1 fraction of the
    schedule) and emits raw latents; the refiner resumes denoising from that
    same fraction and decodes the final image.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what to avoid.
        num_inference_steps: Total denoising steps shared by both stages.
        denoising_switch: Fraction (0–1] at which the base hands off to
            the refiner (`denoising_end` for the base, `denoising_start`
            for the refiner).
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength for both stages.

    Returns:
        The refined `PIL.Image.Image`.
    """
    # Stage 1: stop early and keep latents so the refiner can continue
    # in latent space without an intermediate decode/encode round trip.
    latents = base_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        denoising_end=denoising_switch,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        output_type="latent",
    ).images

    # Stage 2: resume from the exact switch point on the same schedule.
    refined = refiner(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        denoising_start=denoising_switch,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        image=latents,
    ).images[0]

    return refined


def create_ui():
    """Build and return the Gradio Blocks UI for the base+refiner demo."""
    with gr.Blocks() as demo:
        # Read-only display of which checkpoints this Space is running.
        with gr.Row():
            gr.Radio(label="Base model", choices=[base_model_id], value=base_model_id, interactive=False)
            gr.Radio(label="Refiner model", choices=[refiner_id], value=refiner_id, interactive=False)

        with gr.Row():
            # Left column: all generation controls.
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", lines=3)
                negative_box = gr.Textbox(label="Negative Prompt", lines=3, value="low quality, bad quality")
                steps_slider = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=30)
                switch_slider = gr.Slider(label="Denoising Switch", minimum=0.01, maximum=1, step=0.01, value=0.8)
                width_slider = gr.Slider(label="Width", minimum=64, maximum=2048, step=16, value=1024)
                height_slider = gr.Slider(label="Height", minimum=64, maximum=2048, step=16, value=1024)
                cfg_slider = gr.Slider(label="Guidance Scale", minimum=1.1, maximum=30, step=0.1, value=7.5)
            # Right column: output preview and the run trigger.
            with gr.Column():
                result_image = gr.Image(interactive=False)
                run_button = gr.Button("Run", variant="primary")

        # Input order must match generate()'s positional parameters.
        run_button.click(
            generate,
            inputs=[
                prompt_box,
                negative_box,
                steps_slider,
                switch_slider,
                width_slider,
                height_slider,
                cfg_slider,
            ],
            outputs=[result_image],
        )

    return demo


if __name__ == "__main__":
    # share=True asks Gradio to publish a temporary public URL as well.
    create_ui().launch(share=True)