|
# NOTE(review): `spaces` is kept ahead of torch as in the original; Hugging Face
# ZeroGPU guidance is to import it before CUDA-using packages — confirm.
import spaces

import gradio as gr
import torch
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLImg2ImgPipeline,
)
from PIL import Image
|
|
|
# Identifiers of the two pipeline checkpoints loaded below.
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Both pipelines are moved to this device right after loading.
device = "cuda"

# Loading options shared by the base and refiner pipelines.
_HALF_PRECISION_OPTIONS = {
    "torch_dtype": torch.float16,
    "variant": "fp16",
    "use_safetensors": True,
}

# Stand-alone VAE checkpoint loaded in float16; it is handed to the refiner.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

base_pipeline = DiffusionPipeline.from_pretrained(
    base_model_id,
    **_HALF_PRECISION_OPTIONS,
)
base_pipeline = base_pipeline.to(device)

# The refiner reuses the base pipeline's second text encoder and the VAE above
# instead of loading its own copies.
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    refiner_id,
    text_encoder_2=base_pipeline.text_encoder_2,
    vae=vae,
    **_HALF_PRECISION_OPTIONS,
)
refiner = refiner.to(device)
|
|
|
|
|
|
|
# Maps a sampler display name to a factory that takes a scheduler config and
# returns the corresponding scheduler instance. The scheduler classes are only
# looked up when a factory is called, so they must be imported from diffusers
# at the top of the file.
SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config,
        # The scheduler option is named `use_karras_sigmas`; the previous
        # `use_karras` keyword is not a recognised option, so the Karras
        # sigma schedule was never actually enabled.
        use_karras_sigmas=True,
        algorithm_type="sde-dpmsolver++",
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
|
|
|
|
|
|
|
@spaces.GPU(duration=59)
def generate(
    prompt,
    negative_prompt,
    num_inference_steps,
    denoising_switch,
    width, height,
    guidance_scale
):
    """Generate one image by running the base pipeline and then the refiner.

    The base pipeline is called with ``denoising_end=denoising_switch`` and
    ``output_type="latent"``; its output is passed as ``image`` to the refiner,
    which is called with ``denoising_start=denoising_switch``.

    Args:
        prompt: Text prompt given to both pipelines.
        negative_prompt: Negative text prompt given to both pipelines.
        num_inference_steps: Step count given to both pipelines (coerced to int).
        denoising_switch: Value used as the base pipeline's ``denoising_end``
            and the refiner's ``denoising_start``.
        width: Requested output width for the base pipeline (coerced to int).
        height: Requested output height for the base pipeline (coerced to int).
        guidance_scale: Guidance scale given to both pipelines.

    Returns:
        The first image in the refiner's output.
    """
    # UI sliders may deliver whole numbers as floats; the pipelines need ints
    # for step counts and image dimensions.
    steps = int(num_inference_steps)
    width = int(width)
    height = int(height)

    base_latents = base_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        denoising_end=denoising_switch,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        output_type="latent",
    ).images

    # The img2img refiner takes its size from the incoming latents and has no
    # width/height parameters, so they are intentionally not forwarded here.
    generated_image = refiner(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        denoising_start=denoising_switch,
        guidance_scale=guidance_scale,
        image=base_latents,
    ).images[0]

    return generated_image
|
|
|
|
|
def create_ui():
    """Build the Gradio Blocks interface and return it without launching it."""
    with gr.Blocks() as app:
        # Non-interactive display of the two model identifiers in use.
        with gr.Row():
            gr.Radio(label="Base model", choices=[base_model_id], value=base_model_id, interactive=False)
            gr.Radio(label="Refiner model", choices=[refiner_id], value=refiner_id, interactive=False)

        with gr.Row():
            # Left column: generation settings.
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", lines=3)
                negative_box = gr.Textbox(label="Negative Prompt", lines=3)
                steps_slider = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=30)
                switch_slider = gr.Slider(label="Denoising Switch", minimum=0.01, maximum=1, step=0.01, value=0.8)
                width_slider = gr.Slider(label="Width", minimum=64, maximum=2048, step=16, value=1024)
                height_slider = gr.Slider(label="Height", minimum=64, maximum=2048, step=16, value=1024)
                guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.1, maximum=30, step=0.1, value=7.5)

            # Right column: result image and the trigger button.
            with gr.Column():
                result_view = gr.Image(interactive=False)
                run_button = gr.Button("Run", variant="primary")

        # Listed in the same order as the parameters of `generate`.
        generate_inputs = [
            prompt_box,
            negative_box,
            steps_slider,
            switch_slider,
            width_slider,
            height_slider,
            guidance_slider,
        ]
        run_button.click(fn=generate, inputs=generate_inputs, outputs=[result_view])

    return app
|
|
|
|
|
if __name__ == "__main__":
    # Build the interface and start serving it with sharing enabled.
    create_ui().launch(share=True)