import spaces
import gradio as gr
import torch
from diffusers import UNet2DConditionModel, DDIMInverseScheduler, DDIMScheduler
from utils.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from PIL import Image
import argparse
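
# Gradio demo for "Weak-to-Strong Diffusion with Reflection" (W2SD).
# For each prompt the app returns three images: plain SDXL (the weak model), SDXL with the
# selected LoRA enabled (the strong model), and the W2SD result that reflects between the two.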

# Loaded at import time but not used directly below; generate_image() builds its own
# fp16 pipeline. Calling from_pretrained here also downloads/caches the SDXL weights at start-up.
weak_model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
strong_model = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")


def get_generator(random_seed):
    # Seed the CUDA RNG and return the seeded default CPU generator for reproducible sampling.
    seed = int(random_seed)
    torch.cuda.manual_seed(seed)
    return torch.manual_seed(seed)
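
# generate_image() calls get_generator(seed) before each of its three sampling passes, so all
# passes start from the same initial noise and differ only in model / guidance settings.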

# Maps each dropdown label to a LoRA checkpoint path; None means plain SDXL without a LoRA.
model_dict = {
    "SDXL": None,
    "Human Preference": './ckpt/xlMoreArtFullV1.pREw.safetensors',
    'Batman': './ckpt/batman89000003.BlKn.safetensors',
    'Disney': './ckpt/princessXlV2.WSt4.safetensors',
    'Parchment': './ckpt/ParchartXL.safetensors'
}
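# Additional styles could be registered above by mapping a label to a LoRA file in ./ckpt,
# e.g. (hypothetical entry): "My Style": './ckpt/myStyle.safetensors'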


@spaces.GPU(duration=240)
def generate_image(prompt, seed, T, high_cfg, low_cfg, high_lora, low_lora, weak_choice, strong_choice):
    size = 1024
    guidance_scale = 5.5
    lora_scale = 0.8

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Guard against an exactly-zero LoRA scale (presumably a degenerate case for the
    # weak-to-strong reflection below); nudge it to a small non-zero value instead.
    if high_lora == 0:
        high_lora = 0.001
    if low_lora == 0:
        low_lora = -0.001

    # Build the fp16 SDXL pipeline used for all three passes: DDIM for sampling, plus a
    # DDIM inverse scheduler attached for use by the custom pipeline (presumably for inversion).
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    dtype = torch.float16
    pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=dtype,
                                                     variant='fp16',
                                                     safety_checker=None,
                                                     requires_safety_checker=False).to(device)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.inv_scheduler = DDIMInverseScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                              subfolder='scheduler')

    # Attach the selected strong-model LoRA (if any) as a named adapter.
    lora_name = strong_choice
    if model_dict[strong_choice] is not None:
        pipe.load_lora_weights(model_dict[strong_choice], adapter_name=lora_name)

    # Pass 1: the weak model -- plain SDXL with LoRA disabled.
    generator = get_generator(seed)
    pipe.disable_lora()
    image_sdxl = pipe(prompt=prompt, height=size, width=size, guidance_scale=guidance_scale,
                      num_inference_steps=T, generator=generator).images[0]

    # Pass 2: the strong model -- SDXL with the selected LoRA enabled.
    generator = get_generator(seed)
    if model_dict[lora_name] is not None:
        pipe.enable_lora()
        pipe.set_adapters(lora_name, adapter_weights=lora_scale)
    image_dpo_lora = pipe(prompt=prompt, height=size, width=size, guidance_scale=guidance_scale,
                          num_inference_steps=T, generator=generator).images[0]
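
    # Pass 3: W2SD. pipe.w2sd_lora comes from the custom pipeline in
    # utils/pipeline_stable_diffusion_xl.py; presumably it performs the weak-to-strong
    # reflection at each denoising step, contrasting a strong branch (LoRA scale high_lora,
    # guidance high_cfg) against a weak branch (low_lora, low_cfg) via lora_gap_list and
    # cfg_gap_list to steer sampling from the weak model toward the strong one.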
    generator = get_generator(seed)
    pipe.disable_lora()
    image_w2sd = \
        pipe.w2sd_lora(prompt=prompt, height=size, width=size, guidance_scale=guidance_scale,
                       denoise_lora_scale=lora_scale,
                       num_inference_steps=T, generator=generator,
                       lora_gap_list=[high_lora, low_lora],
                       cfg_gap_list=[high_cfg, low_cfg], lora_name=lora_name).images[0]

    return image_sdxl, image_dpo_lora, image_w2sd


with gr.Blocks() as app:
    gr.Markdown("# Weak-to-Strong Diffusion with Reflection")
    gr.Markdown("""
**Note:**
1. The weak model should not be too weak: keep the weak LoRA scale roughly within (-0.5, 0.5); values outside this range may degrade performance (see Figure 9 in the paper).
2. Due to compute limits, avoid setting Timesteps too high (the standard is 50). A value of 10-15 is recommended, since higher values slow generation down considerably.
""")

    with gr.Row():
        weak_image = gr.Image(label="Generated Image by Weak Model", type="pil")
        strong_image = gr.Image(label="Generated Image by Strong Model", type="pil")
        w2sd_image = gr.Image(label="Generated Image via W2SD", type="pil")

    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="A young girl holding a rose.", lines=2)

    with gr.Row():
        seed_slider = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Seed")
        T_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Timesteps")

    with gr.Row():
        high_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=0.8, label="Select Strong LoRA Scale")
        low_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=-0.5, label="Select Weak LoRA Scale")

        high_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=2.0, label="Select Strong Guidance Scale")
        low_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=1.0, label="Select Weak Guidance Scale")

    with gr.Row():
        weak_model_dropdown = gr.Dropdown(choices=["SDXL"], label="Select Weak Model",
                                          value="SDXL")
        strong_model_dropdown = gr.Dropdown(choices=list(model_dict.keys()),
                                            label="Select Strong Model", value="Human Preference")

    generate_button = gr.Button("Generate Image")
    generate_button.click(generate_image,
                          inputs=[prompt_input, seed_slider, T_slider, high_cfg_slider, low_cfg_slider,
                                  high_lora_slider, low_lora_slider, weak_model_dropdown,
                                  strong_model_dropdown],
                          outputs=[weak_image, strong_image, w2sd_image])
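    # Gradio passes the `inputs` values to generate_image positionally, so the list order
    # must match the function signature: prompt, seed, T, high_cfg, low_cfg, high_lora,
    # low_lora, weak_choice, strong_choice.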


app.queue()
app.launch()
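
# To try this demo locally (assuming the custom pipeline in utils/ and the LoRA files in
# ./ckpt are available), save the script as app.py and run `python app.py`. The
# @spaces.GPU decorator targets Hugging Face ZeroGPU Spaces and should have no effect
# when the `spaces` package runs outside a Space.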