import spaces
import gradio as gr
import torch
from diffusers import UNet2DConditionModel, DDIMInverseScheduler, DDIMScheduler
from utils.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from PIL import Image
import argparse

# Hub repo of the SDXL base model; the pipeline and both schedulers are loaded from it.
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# UI model name -> LoRA checkpoint path. ``None`` means plain SDXL (no LoRA is loaded).
model_dict = {
    "SDXL": None,
    "Human Preference": './ckpt/xlMoreArtFullV1.pREw.safetensors',
    'Batman': './ckpt/batman89000003.BlKn.safetensors',
    'Disney': './ckpt/princessXlV2.WSt4.safetensors',
    'Parchment': './ckpt/ParchartXL.safetensors',
}


def get_generator(random_seed):
    """Seed torch's RNGs with *random_seed* and return the seeded default generator.

    Gradio sliders may deliver the seed as a float, so it is cast to ``int`` once.
    ``torch.manual_seed`` seeds the RNG on all devices (CPU and CUDA) and returns
    the default CPU ``torch.Generator``.
    """
    return torch.manual_seed(int(random_seed))


def _load_pipeline(device):
    """Build a fresh SDXL pipeline on *device* with DDIM forward/inverse schedulers."""
    # fp16 is only dependable on GPU; fall back to fp32 on CPU.
    dtype = torch.float16 if device.type == 'cuda' else torch.float32
    pipe = StableDiffusionXLPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        variant='fp16',
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    # NOTE(review): presumably consumed by the custom ``w2sd_lora`` method in
    # utils/pipeline_stable_diffusion_xl — confirm there.
    pipe.inv_scheduler = DDIMInverseScheduler.from_pretrained(MODEL_ID, subfolder='scheduler')
    return pipe


# Generate the weak / strong / W2SD images for one prompt.
@spaces.GPU(duration=240)
def generate_image(prompt, seed, T, high_cfg, low_cfg, high_lora, low_lora, weak_choice, strong_choice):
    """Render the same prompt three ways: weak model, strong model, and W2SD.

    Args:
        prompt: Text prompt.
        seed: RNG seed; every one of the three runs is re-seeded with it.
        T: Number of inference steps (cast to ``int``).
        high_cfg, low_cfg: Strong/weak guidance values passed as ``cfg_gap_list``.
        high_lora, low_lora: Strong/weak LoRA scales passed as ``lora_gap_list``.
        weak_choice: Weak-model selection. Not used; the UI only offers "SDXL".
        strong_choice: Key of ``model_dict`` naming the LoRA used as strong model.

    Returns:
        ``(image_sdxl, image_dpo_lora, image_w2sd)``: the first image of each
        pipeline output.
    """
    size = 1024
    guidance_scale = 5.5
    lora_scale = 0.8
    num_steps = int(T)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Exactly-zero LoRA gaps are nudged away from 0 (positive for strong, negative
    # for weak). The original code only notes "avoid bug"; the failure it works
    # around is inside the custom pipeline and is not visible here.
    if high_lora == 0:
        high_lora = 0.001
    if low_lora == 0:
        low_lora = -0.001

    # A fresh pipeline per request, so ``load_lora_weights`` never collides with
    # an adapter name registered by a previous call.
    pipe = _load_pipeline(device)

    # Load the strong model's LoRA, if any, under its UI name.
    lora_name = strong_choice
    has_lora = model_dict[strong_choice] is not None
    if has_lora:
        pipe.load_lora_weights(model_dict[strong_choice], adapter_name=lora_name)

    common_kwargs = dict(
        prompt=prompt,
        height=size,
        width=size,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
    )

    # Weak model: plain SDXL with all LoRA layers disabled.
    pipe.disable_lora()
    image_sdxl = pipe(**common_kwargs, generator=get_generator(seed)).images[0]

    # Strong model: SDXL with the selected LoRA active at ``lora_scale``.
    if has_lora:
        pipe.enable_lora()
        pipe.set_adapters(lora_name, adapter_weights=lora_scale)
    image_dpo_lora = pipe(**common_kwargs, generator=get_generator(seed)).images[0]

    # W2SD: LoRA is disabled here; ``w2sd_lora`` is given the adapter name and
    # the strong/weak gaps and manages the adapter itself (presumably — see utils/).
    pipe.disable_lora()
    image_w2sd = pipe.w2sd_lora(
        **common_kwargs,
        generator=get_generator(seed),
        denoise_lora_scale=lora_scale,
        lora_gap_list=[high_lora, low_lora],
        cfg_gap_list=[high_cfg, low_cfg],
        lora_name=lora_name,
    ).images[0]

    return image_sdxl, image_dpo_lora, image_w2sd


with gr.Blocks() as app:
    gr.Markdown("# Weak-to-Strong Diffusion with Reflection")
    gr.Markdown("""
**Note:**
1. The weak model should not be too weak. It is recommended to set the weak LoRA scale to around (-0.5, 0.5), as otherwise, performance degradation may occur (refer to Figure 9 in the paper).
2. Due to computational limits, it’s best to avoid setting Timesteps too high (standard is 50). A value of 10-15 is recommended, as higher values can slow down the process significantly.
""")

    with gr.Row():
        weak_image = gr.Image(label="Generated Image by Weak Model", type="pil")
        strong_image = gr.Image(label="Generated Image by Strong Model", type="pil")
        w2sd_image = gr.Image(label="Generated Image via W2SD", type="pil")

    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="A young girl holding a rose.", lines=2)

    with gr.Row():
        seed_slider = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Seed")
        T_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Timesteps")

    with gr.Row():
        high_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=0.8, label="Select Strong LoRA Scale")
        low_lora_slider = gr.Slider(minimum=-2.0, maximum=2.0, step=0.1, value=-0.5, label="Select Weak LoRA Scale")
        high_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=2.0, label="Select Strong Guidance Scale")
        low_cfg_slider = gr.Slider(minimum=-3, maximum=3, step=0.1, value=1.0, label="Select Weak Guidance Scale")

    with gr.Row():
        weak_model_dropdown = gr.Dropdown(choices=["SDXL"], label="Select Weak Model", value="SDXL")
        # Pass a plain list rather than a dict view.
        strong_model_dropdown = gr.Dropdown(choices=list(model_dict), label="Select Strong Model", value="Human Preference")

    generate_button = gr.Button("Generate Image")
    generate_button.click(
        generate_image,
        inputs=[prompt_input, seed_slider, T_slider, high_cfg_slider, low_cfg_slider,
                high_lora_slider, low_lora_slider, weak_model_dropdown, strong_model_dropdown],
        outputs=[weak_image, strong_image, w2sd_image],
    )

# Enable request queuing (needed for long-running GPU jobs).
app.queue()
app.launch(server_name='0.0.0.0', share=True, server_port=7788)