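# CreatiLayout layout-to-image demo: Stable Diffusion 3 Medium paired with the
# SiamLayout transformer, served as a Gradio app on a ZeroGPU Space.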
import os

import gradio as gr
import pandas as pd
import spaces
import torch
from huggingface_hub import login
from PIL import Image

from src.models.transformer_sd3_SiamLayout import SiamLayoutSD3Transformer2DModel
from src.pipeline.pipeline_CreatiLayout import CreatiLayoutSD3Pipeline
from utils.bbox_visualization import bbox_visualization, scale_boxes
# Authenticate with the Hugging Face Hub (the SD3 weights are gated).
hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret.")
login(token=hf_token)
# Base SD3 weights plus the CreatiLayout SiamLayout transformer checkpoint.
model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
ckpt_path = "Benson1237/CreatiLayout"

transformer_additional_kwargs = dict(attention_type="layout", strict=True)

transformer = SiamLayoutSD3Transformer2DModel.from_pretrained(
    ckpt_path, subfolder="transformer", torch_dtype=torch.float16, **transformer_additional_kwargs
)
pipe = CreatiLayoutSD3Pipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
print("pipeline is loaded.")
@spaces.GPU  # On ZeroGPU, a GPU is attached only inside functions decorated with @spaces.GPU.
def process_image_and_text(
    global_caption,
    box_detail_phrases_list: pd.DataFrame,
    boxes: pd.DataFrame,
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
):
    if randomize_seed:
        seed = torch.randint(0, 100, (1,)).item()

    height = 1024
    width = 1024

    # gr.Dataframe passes rows; keep only the single caption column as a flat list.
    box_detail_phrases_list_tmp = [row[0] for row in box_detail_phrases_list.values.tolist()]
    boxes = boxes.astype(float).values.tolist()

    # Draw the requested layout on a blank white canvas (256 is out of range for
    # 8-bit RGB channels, which top out at 255).
    white_image = Image.new("RGB", (width, height), color="white")
    show_input = {"boxes": scale_boxes(boxes, width, height), "labels": box_detail_phrases_list_tmp}
    bbox_visualization_img = bbox_visualization(white_image, show_input)

    result_img = pipe(
        prompt=global_caption,
        generator=torch.Generator(device="cuda").manual_seed(seed),
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        bbox_phrases=box_detail_phrases_list_tmp,
        bbox_raw=boxes,
        height=height,
        width=width,
    ).images[0]

    return bbox_visualization_img, result_img
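# A minimal direct-call sketch (an assumption for illustration; the DataFrames mirror
# the row-oriented values gr.Dataframe passes in):
#
#     captions_df = pd.DataFrame([["a red ball on the left"], ["a blue cube on the right"]])
#     boxes_df = pd.DataFrame([[0.05, 0.30, 0.45, 0.70], [0.55, 0.30, 0.95, 0.70]])
#     layout_img, image = process_image_and_text("two toys on a table", captions_df, boxes_df)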
def get_samples():
    # Three worked layout examples: a global caption, one phrase per region, and
    # normalized (x_min, y_min, x_max, y_max) boxes in [0, 1].
    sample_list = [
        {
            "global_caption": "A picturesque scene features Iron Man standing confidently on a rugged rock by the sea, holding a drawing board with his hands. The board displays the words 'Creative Layout' in a playful, hand-drawn font. The serene sea shimmers under the setting sun. The sky is painted with a gradient of warm colors, from deep oranges to soft purples.",
            "region_caption_list": [
                "Iron Man standing confidently on a rugged rock.",
                "A rugged rock by the sea.",
                "A drawing board with the words \"Creative Layout\" in a playful, hand-drawn font.",
                "The serene sea shimmers under the setting sun.",
                "The sky is a shade of deep orange to soft purple.",
            ],
            "region_bboxes_list": [
                [0.40, 0.35, 0.55, 0.80],
                [0.35, 0.75, 0.60, 0.95],
                [0.40, 0.45, 0.55, 0.65],
                [0.00, 0.30, 1.00, 0.90],
                [0.00, 0.00, 1.00, 0.30],
            ],
        },
        {
            "global_caption": "This is a photo showcasing two wooden benches in a park. The bench on the left is painted in a vibrant blue, while the one on the right is painted in a green. Both are placed on a path paved with stones, surrounded by lush trees and shrubs. The sunlight filters through the leaves, casting dappled shadows on the ground, creating a tranquil and comfortable atmosphere.",
            "region_caption_list": [
                "A weathered, blue wooden bench with green elements in a natural setting.",
                "Old, weathered wooden benches with green and blue paint.",
                "A dirt path in a park with green grass on the sides and two colorful wooden benches.",
                "Thick, verdant foliage of mature trees in a dense forest.",
            ],
            "region_bboxes_list": [
                [0.30, 0.44, 0.62, 0.78],
                [0.54, 0.41, 0.75, 0.65],
                [0.00, 0.39, 1.00, 1.00],
                [0.00, 0.00, 1.00, 0.43],
            ],
        },
        {
            "global_caption": "This is a wedding photo taken in a photography studio, showing a newlywed couple sitting on a brown leather sofa in a modern indoor setting. The groom is dressed in a pink suit, paired with a pink tie and white shirt, while the bride is wearing a white wedding dress with a long veil. They are sitting on a brown leather sofa, with a wooden table in front of them, on which a bouquet of flowers is placed. The background is a bar with a staircase and a wall decorated with lights, creating a warm and romantic atmosphere.",
            "region_caption_list": [
                "A floral arrangement consisting of roses, carnations, and eucalyptus leaves on a wooden surface.",
                "A white wedding dress with off-the-shoulder ruffles and a long, sheer veil.",
                "A polished wooden table with visible grain and knots.",
                "A close-up of a dark brown leather sofa with tufted upholstery and button details.",
                "A man in a pink suit with a white shirt and red tie, sitting on a leather armchair.",
                "A person in a suit seated on a leather armchair near a wooden staircase with books and bottles.",
                "Bride in white gown with veil, groom in maroon suit and pink tie, seated on leather armchairs.",
            ],
            "region_bboxes_list": [
                [0.09, 0.65, 0.31, 0.93],
                [0.62, 0.25, 0.89, 0.90],
                [0.01, 0.70, 0.78, 0.99],
                [0.76, 0.65, 1.00, 0.99],
                [0.27, 0.32, 0.72, 0.75],
                [0.00, 0.01, 0.52, 0.72],
                [0.27, 0.09, 0.94, 0.89],
            ],
        },
    ]
    # gr.Dataframe expects rows, so each region caption is wrapped in a single-column row.
    return [
        [
            sample["global_caption"],
            [[caption] for caption in sample["region_caption_list"]],
            sample["region_bboxes_list"],
        ]
        for sample in sample_list
    ]
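# Gradio UI: global caption and per-region tables on the left; the drawn layout
# and the generated image on the right.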
with gr.Blocks() as demo:
    gr.Markdown("# CreatiLayout / Layout-to-Image Generation")

    with gr.Row():
        with gr.Column():
            global_caption = gr.Textbox(lines=2, label="Global Caption")
            box_detail_phrases_list = gr.Dataframe(headers=["Region Captions"], label="Region Captions")
            boxes = gr.Dataframe(
                headers=["x1", "y1", "x2", "y2"],
                label="Region Bounding Boxes (x_min, y_min, x_max, y_max)",
            )
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(0, 100, step=1, label="Seed", value=42)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                guidance_scale = gr.Slider(1, 30, step=0.5, label="Guidance Scale", value=7.5)
                num_inference_steps = gr.Slider(1, 50, step=1, label="Number of inference steps", value=50)
        with gr.Column():
            bbox_visualization_img = gr.Image(type="pil", label="Bounding Box Visualization")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    gr.Button("Generate").click(
        fn=process_image_and_text,
        inputs=[global_caption, box_detail_phrases_list, boxes, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[bbox_visualization_img, output_image],
    )

    gr.Examples(
        examples=get_samples(),
        inputs=[global_caption, box_detail_phrases_list, boxes],
        outputs=[bbox_visualization_img, output_image],
        fn=process_image_and_text,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()