# Hugging Face Space status: Running on ZeroGPU ("Running on Zero").
import argparse
import os

import gradio as gr
import spaces
import torch
from diffusers import DDIMInverseScheduler, DDIMScheduler, UNet2DConditionModel
from PIL import Image

from utils.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
# Two SDXL base pipelines loaded at import time (default dtype and device, since
# no torch_dtype / .to() is given).
# NOTE(review): neither ``weak_model`` nor ``strong_model`` is referenced anywhere
# else in this file -- ``generate_image`` constructs its own pipeline on every
# call -- yet each load pulls a full SDXL checkpoint into memory. Confirm they are
# still needed before keeping them.
weak_model, strong_model = (
    StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    for _ in range(2)
)
def get_generator(random_seed):
    """Reseed torch's global RNGs with ``random_seed`` and return the default generator.

    Both the CPU and the CUDA global generators are seeded, so sampling that falls
    back to the global RNG is reproducible as well as sampling that uses the
    returned generator.

    Args:
        random_seed: Seed value; anything ``int()`` accepts (e.g. ``42`` or ``42.0``).

    Returns:
        The process-wide default ``torch.Generator`` (what ``torch.manual_seed``
        returns), freshly reset to ``random_seed`` -- not a private generator.
    """
    seed = int(random_seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # One more reseed; its return value is the reset default CPU generator.
    return torch.manual_seed(random_seed)
# Strong-model choices: display name -> LoRA checkpoint path loaded on top of the
# SDXL base pipeline. ``None`` means "SDXL base, no LoRA". Insertion order is the
# order of the "Select Strong Model" dropdown choices.
model_dict = {
    "SDXL": None,
    "Human Preference": "./ckpt/xlMoreArtFullV1.pREw.safetensors",
    "Batman": "./ckpt/batman89000003.BlKn.safetensors",
    "Disney": "./ckpt/princessXlV2.WSt4.safetensors",
    "Parchment": "./ckpt/ParchartXL.safetensors",
}
def _build_pipeline(device, dtype):
    """Load a fresh SDXL base pipeline on ``device`` with DDIM forward/inverse schedulers.

    The pipeline's ``scheduler`` is replaced by a ``DDIMScheduler`` built from the
    checkpoint's scheduler config, and an extra ``inv_scheduler`` attribute is set
    to a ``DDIMInverseScheduler`` loaded from the same checkpoint.
    """
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        variant='fp16',
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    # NOTE(review): inv_scheduler is not read in this file; presumably the custom
    # pipeline's ``w2sd_lora`` uses it -- confirm in utils/pipeline_stable_diffusion_xl.
    pipe.inv_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder='scheduler')
    return pipe


# Generate the three comparison images. ``spaces.GPU`` requests a GPU for the
# duration of the call on ZeroGPU Spaces and, per the ``spaces`` docs, has no
# effect elsewhere.
# NOTE(review): the decorator's default GPU time budget may be too short for
# large ``T``; consider ``@spaces.GPU(duration=...)`` if calls time out.
@spaces.GPU
def generate_image(prompt, seed, T, high_cfg, low_cfg, high_lora, low_lora, weak_choice, strong_choice):
    """Generate weak-model, strong-model and W2SD images for ``prompt``.

    All three samples use a 1024x1024 resolution, guidance scale 5.5 and the same
    ``seed`` (the RNG is reseeded before each sample).

    Args:
        prompt: Text prompt.
        seed: RNG seed passed to ``get_generator`` before every sample.
        T: Number of inference steps (cast to ``int``).
        high_cfg, low_cfg: Passed to ``w2sd_lora`` as ``cfg_gap_list``.
        high_lora, low_lora: Passed to ``w2sd_lora`` as ``lora_gap_list``; exact
            zeros are replaced by +0.001 / -0.001.
        weak_choice: Not read; the weak model is always the SDXL base with LoRA
            disabled (the UI only offers "SDXL").
        strong_choice: Key into ``model_dict``; its LoRA (if any) is loaded as the
            strong model under the adapter name ``strong_choice``.

    Returns:
        ``(image_sdxl, image_dpo_lora, image_w2sd)`` -- ``.images[0]`` of each call.
    """
    size = 1024
    guidance_scale = 5.5
    lora_scale = 0.8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # float16 is only dependable on GPU; on the CPU fallback many kernels are
    # missing or very slow in half precision, so use float32 there.
    dtype = torch.float16 if device.type == 'cuda' else torch.float32
    num_steps = int(T)

    # The original code nudges exact zeros off zero to "avoid bug"; presumably a
    # zero scale is mishandled downstream in w2sd_lora -- TODO confirm.
    if high_lora == 0:
        high_lora = 0.001
    if low_lora == 0:
        low_lora = -0.001

    pipe = _build_pipeline(device, dtype)

    # Load the selected LoRA (e.g. the DPO / preference LoRA) as the strong model.
    lora_name = strong_choice
    lora_path = model_dict[strong_choice]
    if lora_path is not None:
        pipe.load_lora_weights(lora_path, adapter_name=lora_name)

    sample_kwargs = dict(
        prompt=prompt,
        height=size,
        width=size,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
    )

    # Weak model: SDXL base with any LoRA disabled.
    generator = get_generator(seed)
    pipe.disable_lora()
    image_sdxl = pipe(generator=generator, **sample_kwargs).images[0]

    # Strong model: base + selected LoRA at ``lora_scale`` (plain base for "SDXL").
    generator = get_generator(seed)
    if lora_path is not None:
        pipe.enable_lora()
        pipe.set_adapters(lora_name, adapter_weights=lora_scale)
    image_dpo_lora = pipe(generator=generator, **sample_kwargs).images[0]

    # W2SD: the custom pipeline method drives the LoRA / CFG gaps itself.
    # NOTE(review): for "SDXL" no adapter is loaded but ``lora_name="SDXL"`` is
    # still passed; confirm w2sd_lora handles a missing adapter.
    generator = get_generator(seed)
    pipe.disable_lora()
    image_w2sd = pipe.w2sd_lora(
        generator=generator,
        denoise_lora_scale=lora_scale,
        lora_gap_list=[high_lora, low_lora],
        cfg_gap_list=[high_cfg, low_cfg],
        lora_name=lora_name,
        **sample_kwargs,
    ).images[0]
    return image_sdxl, image_dpo_lora, image_w2sd
with gr.Blocks() as app:
    # Title and usage notes shown above the controls.
    gr.Markdown("# Weak-to-Strong Diffusion with Reflection")
    gr.Markdown("""
**Note:**
1. The weak model should not be too weak. It is recommended to set the weak LoRA scale to around (-0.5, 0.5), as otherwise, performance degradation may occur (refer to Figure 9 in the paper).
2. Due to computational limits, it’s best to avoid setting Timesteps too high (standard is 50). A value of 10-15 is recommended, as higher values can slow down the process significantly.
""")

    # Output row, in the same order as generate_image's return tuple.
    with gr.Row():
        weak_image = gr.Image(type="pil", label="Generated Image by Weak Model")
        strong_image = gr.Image(type="pil", label="Generated Image by Strong Model")
        w2sd_image = gr.Image(type="pil", label="Generated Image via W2SD")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            placeholder="A young girl holding a rose.",
            lines=2,
        )

    with gr.Row():
        seed_slider = gr.Slider(label="Seed", minimum=0, maximum=9999, step=1, value=42)
        T_slider = gr.Slider(label="Timesteps", minimum=5, maximum=50, step=1, value=10)

    # W2SD knobs: LoRA scales and guidance scales for the strong / weak branches.
    with gr.Row():
        high_lora_slider = gr.Slider(
            label="Select Strong LoRA Scale", minimum=-2.0, maximum=2.0, step=0.1, value=0.8,
        )
        low_lora_slider = gr.Slider(
            label="Select Weak LoRA Scale", minimum=-2.0, maximum=2.0, step=0.1, value=-0.5,
        )
        high_cfg_slider = gr.Slider(
            label="Select Strong Guidance Scale", minimum=-3, maximum=3, step=0.1, value=2.0,
        )
        low_cfg_slider = gr.Slider(
            label="Select Weak Guidance Scale", minimum=-3, maximum=3, step=0.1, value=1.0,
        )

    with gr.Row():
        weak_model_dropdown = gr.Dropdown(
            label="Select Weak Model", choices=["SDXL"], value="SDXL",
        )
        strong_model_dropdown = gr.Dropdown(
            label="Select Strong Model", choices=list(model_dict), value="Human Preference",
        )

    generate_button = gr.Button("Generate Image")
    # Component order must match generate_image's positional parameters.
    generate_button.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            seed_slider,
            T_slider,
            high_cfg_slider,
            low_cfg_slider,
            high_lora_slider,
            low_lora_slider,
            weak_model_dropdown,
            strong_model_dropdown,
        ],
        outputs=[weak_image, strong_image, w2sd_image],
    )
# Enable Gradio's request queue for the (long-running) generation events.
app.queue()
# An explicit server_port overrides Gradio's GRADIO_SERVER_PORT environment
# variable, which a hosting platform may set to the port it routes traffic to.
# Honour that variable when present and keep the original port 7788 otherwise.
app.launch(
    server_name='0.0.0.0',
    share=True,
    server_port=int(os.environ.get('GRADIO_SERVER_PORT', 7788)),
)