Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,271 Bytes
8b1e96d 0cffd40 8b1e96d 0cffd40 8b1e96d 0cffd40 8b1e96d 0cccf69 8b1e96d ec35e66 8b1e96d 0cccf69 8b1e96d f286ae5 8b1e96d 6e61ba6 0cccf69 8b1e96d 0cffd40 8b1e96d 0cffd40 8b1e96d 0cffd40 8b1e96d 0cffd40 556fb50 8b1e96d 3eaeeea 8b1e96d 0cffd40 8b1e96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import gradio as gr
import torch
import PIL
# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"  # base pipeline the UNet is plugged into
repo = "tianweiy/DMD2"  # Hub repo the checkpoint files are downloaded from

# Checkpoint label -> [weights file name inside `repo`, number of inference steps].
checkpoints = {
    f"{steps}-Step": [f"dmd2_sdxl_{steps}step_unet_fp16.bin", steps]
    for steps in (1, 4)
}

# Step count of the checkpoint most recently loaded into the UNet (None until the first load).
loaded = None

# Stylesheet handed to gr.Blocks.
CSS = """
.gradio-container {
max-width: 690px !important;
}
"""
# Build the UNet and the SDXL pipeline once, at import time, when CUDA is
# reported as available. The UNet is instantiated from the base model's config
# only; generate_image() loads the DMD2 weights into it on first use.
if torch.cuda.is_available():
    # Passing a repo id straight to from_config() is deprecated in diffusers:
    # fetch the config dict with load_config() and build the model from that.
    unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to("cuda", torch.float16)
    pipe = DiffusionPipeline.from_pretrained(
        base, unet=unet, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")
# Function
# Denoising timesteps per step count, following the usage examples in the DMD2
# model card. Step counts without an entry fall back to the scheduler's default.
_DMD2_TIMESTEPS = {
    1: [399],
    4: [999, 749, 499, 249],
}


@spaces.GPU()
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` using the selected DMD2 checkpoint.

    The checkpoint weights are downloaded and loaded into the shared UNet only
    when the requested step count differs from the one recorded in ``loaded``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        ckpt: Key into ``checkpoints`` (for example ``"4-Step"``).

    Returns:
        The first image of the pipeline result.

    Raises:
        gr.Error: If ``ckpt`` is not a key of ``checkpoints``.
    """
    global loaded
    print(prompt, ckpt)

    # Report an unsupported selection to the user instead of a bare KeyError.
    if ckpt not in checkpoints:
        raise gr.Error(
            f"Unknown checkpoint {ckpt!r}; choose one of: {', '.join(checkpoints)}"
        )
    checkpoint, num_inference_steps = checkpoints[ckpt]

    if loaded != num_inference_steps:
        # NOTE(review): torch.load unpickles the downloaded file. Consider
        # weights_only=True (or a safetensors checkpoint) if the installed
        # torch version supports it.
        state_dict = torch.load(hf_hub_download(repo, checkpoint), map_location="cuda")
        unet.load_state_dict(state_dict)
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
        # Recorded only after a successful load so a failed attempt is retried.
        loaded = num_inference_steps

    results = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        timesteps=_DMD2_TIMESTEPS.get(num_inference_steps),
    )
    return results.images[0]
# Gradio Interface
with gr.Blocks(css=CSS) as demo:
    gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Offer only the labels that have an entry in `checkpoints`; any
            # other choice cannot be resolved by generate_image().
            ckpt = gr.Dropdown(
                label='Select inference steps',
                choices=list(checkpoints),
                value='4-Step',
                interactive=True,
            )
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='DMD2 Generated Image')

    # Submitting the prompt box and clicking the button run the same handler.
    for trigger in (prompt.submit, submit.click):
        trigger(
            fn=generate_image,
            inputs=[prompt, ckpt],
            outputs=img,
        )

demo.queue().launch()