# Gradio demo app: Weak-to-Strong Diffusion (W2SD) with SDXL and LoRA adapters.
import gradio as gr
import torch
from diffusers import UNet2DConditionModel, DDIMInverseScheduler, DDIMScheduler
from utils.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
import torch
from PIL import Image
import argparse
import spaces

# Instantiate two SDXL base pipelines at import time (default dtype/variant).
# Neither name is referenced again in the visible part of this file;
# generate_image builds its own fp16 pipeline on every call.
# NOTE(review): presumably kept to warm the model download cache at startup —
# confirm before removing, since this is a module-level side effect.
weak_model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
strong_model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

def get_generator(random_seed):
    """Seed PyTorch's RNGs and return the default CPU generator.

    Args:
        random_seed: Seed value; anything accepted by ``int()`` (UI sliders
            may deliver it as a float).

    Returns:
        The ``torch.Generator`` returned by ``torch.manual_seed``, i.e. the
        global default generator seeded with ``int(random_seed)``.
    """
    # Convert once so every seeding call sees exactly the same integer.
    seed = int(random_seed)
    torch.cuda.manual_seed(seed)
    # torch.manual_seed returns the (global) default generator it just seeded.
    return torch.manual_seed(seed)

# Strong-model choices offered in the UI, mapped to the LoRA checkpoint that
# is loaded for them. ``None`` means no LoRA weights are loaded for that entry.
model_dict = {
    "SDXL": None,
    "Human Preference": "./ckpt/xlMoreArtFullV1.pREw.safetensors",
    "Batman": "./ckpt/batman89000003.BlKn.safetensors",
    "Disney": "./ckpt/princessXlV2.WSt4.safetensors",
    "Parchment": "./ckpt/ParchartXL.safetensors",
}

# Image generation callback used by the Gradio UI.
@spaces.GPU(duration=240)
def generate_image(prompt, seed, T, high_cfg, low_cfg, high_lora, low_lora, weak_choice, strong_choice):
    """Generate three images for one prompt: weak model, strong model, W2SD.

    A fresh fp16 SDXL pipeline is built on every call. The same seed is
    re-applied before each of the three generations so that all of them start
    from the same random state.

    Args:
        prompt: Text prompt forwarded to every pipeline call.
        seed: Random seed handed to ``get_generator`` before each generation.
        T: Number of inference steps; cast to ``int`` because UI sliders may
            deliver numeric values as floats.
        high_cfg: First element of ``cfg_gap_list`` passed to ``w2sd_lora``.
        low_cfg: Second element of ``cfg_gap_list`` passed to ``w2sd_lora``.
        high_lora: First element of ``lora_gap_list``; an exact 0 is replaced
            by 0.001.
        low_lora: Second element of ``lora_gap_list``; an exact 0 is replaced
            by -0.001.
        weak_choice: Accepted for interface compatibility; not used here.
        strong_choice: Key into ``model_dict`` selecting the LoRA checkpoint;
            also used as the adapter name.

    Returns:
        Tuple ``(image_sdxl, image_dpo_lora, image_w2sd)`` holding the first
        image of each pipeline result.

    Raises:
        KeyError: If ``strong_choice`` is not a key of ``model_dict``.
    """
    # Fixed generation settings.
    size = 1024
    guidance_scale = 5.5
    lora_scale = 0.8
    steps = int(T)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Nudge exact zeros away from 0; the original author marked this as a
    # workaround ("avoid bug") for the downstream W2SD call.
    if high_lora == 0:
        high_lora = 0.001
    if low_lora == 0:
        low_lora = -0.001

    # Build the base pipeline plus forward / inverse DDIM schedulers.
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    dtype = torch.float16
    pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=dtype,
                                                     variant='fp16',
                                                     safety_checker=None, requires_safety_checker=False).to(device)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.inv_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder='scheduler')

    # Load the selected LoRA (if any) as the strong model's adapter.
    lora_name = strong_choice
    lora_path = model_dict[strong_choice]
    has_lora = lora_path is not None
    if has_lora:
        pipe.load_lora_weights(lora_path, adapter_name=lora_name)

    # Arguments shared by all three pipeline invocations.
    common_kwargs = dict(prompt=prompt, height=size, width=size,
                         guidance_scale=guidance_scale,
                         num_inference_steps=steps)

    # Weak model: LoRA disabled.
    generator = get_generator(seed)
    pipe.disable_lora()
    image_sdxl = pipe(generator=generator, **common_kwargs).images[0]

    # Strong model: LoRA enabled at the fixed scale (when one was loaded).
    generator = get_generator(seed)
    if has_lora:
        pipe.enable_lora()
        pipe.set_adapters(lora_name, adapter_weights=lora_scale)
    image_dpo_lora = pipe(generator=generator, **common_kwargs).images[0]

    # W2SD: LoRA is disabled here and the adapter name is handed to w2sd_lora.
    generator = get_generator(seed)
    pipe.disable_lora()
    image_w2sd = pipe.w2sd_lora(denoise_lora_scale=lora_scale,
                                generator=generator,
                                lora_gap_list=[high_lora, low_lora],
                                cfg_gap_list=[high_cfg, low_cfg],
                                lora_name=lora_name,
                                **common_kwargs).images[0]

    return image_sdxl, image_dpo_lora, image_w2sd

# Build the Gradio UI: three output images, the prompt box, the tuning
# sliders, the model selectors and the button that triggers generate_image.
with gr.Blocks() as app:
    gr.Markdown("# Weak-to-Strong Diffusion with Reflection")
    gr.Markdown("""
        **Note:**
        1. The weak model should not be too weak. It is recommended to set the weak LoRA scale to around (-0.5, 0.5), as otherwise, performance degradation may occur (refer to Figure 9 in the paper).
        2. Due to computational limits, it’s best to avoid setting Timesteps too high (standard is 50). A value of 10-15 is recommended, as higher values can slow down the process significantly.
        """)

    # Output images, in the same order generate_image returns them.
    with gr.Row():
        weak_image = gr.Image(label="Generated Image by Weak Model", type="pil")
        strong_image = gr.Image(label="Generated Image by Strong Model", type="pil")
        w2sd_image = gr.Image(label="Generated Image via W2SD", type="pil")

    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="A young girl holding a rose.", lines=2)

    with gr.Row():
        seed_slider = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Seed")
        T_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Timesteps")

    # LoRA scale and guidance scale pairs (strong / weak).
    with gr.Row():
        high_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=0.8, label="Select Strong LoRA Scale")
        low_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=-0.5, label="Select Weak LoRA Scale")

        high_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=2.0, label="Select Strong Guidance Scale")
        low_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=1.0, label="Select Weak Guidance Scale")

    with gr.Row():
        weak_model_dropdown = gr.Dropdown(choices=["SDXL"], label="Select Weak Model",
                                          value="SDXL")
        # Pass a real list of keys rather than a dict view.
        strong_model_dropdown = gr.Dropdown(choices=list(model_dict),
                                            label="Select Strong Model", value="Human Preference")

    # Input order must match generate_image's positional parameters.
    generate_button = gr.Button("Generate Image")
    generate_button.click(generate_image,
                          inputs=[prompt_input, seed_slider, T_slider, high_cfg_slider, low_cfg_slider, high_lora_slider, low_lora_slider, weak_model_dropdown,
                                  strong_model_dropdown],
                          outputs=[weak_image, strong_image, w2sd_image])

    # Enable the queue feature
    app.queue()

# Serve on all interfaces at port 7788 and request a public share link.
app.launch(server_name='0.0.0.0', share=True, server_port=7788)