# CreatiLayout / app.py
import os

import gradio as gr
import pandas as pd
import spaces
import torch
from huggingface_hub import login
from PIL import Image

from src.models.transformer_sd3_SiamLayout import SiamLayoutSD3Transformer2DModel
from src.pipeline.pipeline_CreatiLayout import CreatiLayoutSD3Pipeline
from utils.bbox_visualization import bbox_visualization, scale_boxes
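# `scale_boxes` and `bbox_visualization` come from this repo's utils module.
# A rough sketch of the assumed behavior (not the actual implementation):
#
#     def scale_boxes(boxes, width, height):
#         # Map normalized [x1, y1, x2, y2] boxes (0-1) to pixel coordinates.
#         return [[x1 * width, y1 * height, x2 * width, y2 * height]
#                 for x1, y1, x2, y2 in boxes]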
# Authenticate with the Hugging Face Hub; the checkpoints below may require it.
hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret.")
login(token=hf_token)
model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
ckpt_path = "Benson1237/CreatiLayout"

# Load the SiamLayout transformer (layout-conditioned attention) and plug it
# into the CreatiLayout SD3 pipeline in place of the stock SD3 transformer.
transformer_additional_kwargs = dict(attention_type="layout", strict=True)
transformer = SiamLayoutSD3Transformer2DModel.from_pretrained(
    ckpt_path,
    subfolder="transformer",
    torch_dtype=torch.float16,
    **transformer_additional_kwargs,
)
pipe = CreatiLayoutSD3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
print("pipeline is loaded.")
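# If GPU memory is tight, a diffusers pipeline can usually trade the hard
# .to("cuda") above for CPU offload (assuming CreatiLayoutSD3Pipeline
# inherits from diffusers' DiffusionPipeline):
#
#     # pipe.enable_model_cpu_offload()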
@spaces.GPU
def process_image_and_text(
    global_caption,
    box_detail_phrases_list: pd.DataFrame,
    boxes: pd.DataFrame,
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
):
    if randomize_seed:
        # Stay within the 0-100 range exposed by the seed slider below.
        seed = torch.randint(0, 100, (1,)).item()
    height = 1024
    width = 1024

    # The first (only) column of the dataframe holds one caption per region.
    box_detail_phrases_list_tmp = box_detail_phrases_list.values.tolist()
    box_detail_phrases_list_tmp = [c[0] for c in box_detail_phrases_list_tmp]
    boxes = boxes.astype(float).values.tolist()

    # Draw the requested layout on a blank white canvas so the user can
    # inspect it alongside the generated image.
    white_image = Image.new("RGB", (width, height), color="white")
    show_input = {"boxes": scale_boxes(boxes, width, height), "labels": box_detail_phrases_list_tmp}
    bbox_visualization_img = bbox_visualization(white_image, show_input)

    result_img = pipe(
        prompt=global_caption,
        generator=torch.Generator(device="cuda").manual_seed(seed),
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        bbox_phrases=box_detail_phrases_list_tmp,
        bbox_raw=boxes,
        height=height,
        width=width,
    ).images[0]
    return bbox_visualization_img, result_img
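# A minimal sketch of calling the handler directly (outside the UI), assuming
# a CUDA device and the pipeline above; the dataframe shapes mirror what the
# Gradio components below pass in:
#
#     captions = pd.DataFrame([["A red kite"], ["A grassy hill"]])
#     boxes_df = pd.DataFrame([[0.2, 0.1, 0.6, 0.4], [0.0, 0.5, 1.0, 1.0]])
#     layout_img, gen_img = process_image_and_text(
#         "A red kite flying over a grassy hill", captions, boxes_df)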
def get_samples():
    # Each sample pairs a global caption with per-region captions and
    # normalized [x_min, y_min, x_max, y_max] boxes in the 0-1 range.
    sample_list = [
        {
            "global_caption": "A picturesque scene features Iron Man standing confidently on a rugged rock by the sea, holding a drawing board with his hands. The board displays the words 'Creative Layout' in a playful, hand-drawn font. The serene sea shimmers under the setting sun. The sky is painted with a gradient of warm colors, from deep oranges to soft purples.",
            "region_caption_list": [
                "Iron Man standing confidently on a rugged rock.",
                "A rugged rock by the sea.",
                "A drawing board with the words \"Creative Layout\" in a playful, hand-drawn font.",
                "The serene sea shimmers under the setting sun.",
                "The sky is a shade of deep orange to soft purple.",
            ],
            "region_bboxes_list": [
                [0.40, 0.35, 0.55, 0.80],
                [0.35, 0.75, 0.60, 0.95],
                [0.40, 0.45, 0.55, 0.65],
                [0.00, 0.30, 1.00, 0.90],
                [0.00, 0.00, 1.00, 0.30],
            ],
        },
        {
            "global_caption": "This is a photo showcasing two wooden benches in a park. The bench on the left is painted in a vibrant blue, while the one on the right is painted in green. Both are placed on a path paved with stones, surrounded by lush trees and shrubs. The sunlight filters through the leaves, casting dappled shadows on the ground, creating a tranquil and comfortable atmosphere.",
            "region_caption_list": [
                "A weathered, blue wooden bench with green elements in a natural setting.",
                "Old, weathered wooden benches with green and blue paint.",
                "A dirt path in a park with green grass on the sides and two colorful wooden benches.",
                "Thick, verdant foliage of mature trees in a dense forest.",
            ],
            "region_bboxes_list": [
                [0.30, 0.44, 0.62, 0.78],
                [0.54, 0.41, 0.75, 0.65],
                [0.00, 0.39, 1.00, 1.00],
                [0.00, 0.00, 1.00, 0.43],
            ],
        },
        {
            "global_caption": "This is a wedding photo taken in a photography studio, showing a newlywed couple sitting on a brown leather sofa in a modern indoor setting. The groom is dressed in a pink suit, paired with a pink tie and white shirt, while the bride is wearing a white wedding dress with a long veil. They are sitting on a brown leather sofa, with a wooden table in front of them, on which a bouquet of flowers is placed. The background is a bar with a staircase and a wall decorated with lights, creating a warm and romantic atmosphere.",
            "region_caption_list": [
                "A floral arrangement consisting of roses, carnations, and eucalyptus leaves on a wooden surface.",
                "A white wedding dress with off-the-shoulder ruffles and a long, sheer veil.",
                "A polished wooden table with visible grain and knots.",
                "A close-up of a dark brown leather sofa with tufted upholstery and button details.",
                "A man in a pink suit with a white shirt and red tie, sitting on a leather armchair.",
                "A person in a suit seated on a leather armchair near a wooden staircase with books and bottles.",
                "Bride in white gown with veil, groom in maroon suit and pink tie, seated on leather armchairs.",
            ],
            "region_bboxes_list": [
                [0.09, 0.65, 0.31, 0.93],
                [0.62, 0.25, 0.89, 0.90],
                [0.01, 0.70, 0.78, 0.99],
                [0.76, 0.65, 1.00, 0.99],
                [0.27, 0.32, 0.72, 0.75],
                [0.00, 0.01, 0.52, 0.72],
                [0.27, 0.09, 0.94, 0.89],
            ],
        },
    ]
    # Wrap each region caption in its own list so it fills the one-column
    # "Region Captions" dataframe component.
    return [
        [
            sample["global_caption"],
            [[caption] for caption in sample["region_caption_list"]],
            sample["region_bboxes_list"],
        ]
        for sample in sample_list
    ]
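# Each returned row lines up with the gr.Examples inputs below, e.g. (abridged):
#
#     ["A picturesque scene features Iron Man ...",
#      [["Iron Man standing confidently on a rugged rock."], ...],
#      [[0.40, 0.35, 0.55, 0.80], ...]]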
with gr.Blocks() as demo:
    gr.Markdown("# CreatiLayout / Layout-to-Image generation")
    with gr.Row():
        with gr.Column():
            global_caption = gr.Textbox(lines=2, label="Global Caption")
            box_detail_phrases_list = gr.Dataframe(headers=["Region Captions"], label="Region Captions")
            boxes = gr.Dataframe(
                headers=["x1", "y1", "x2", "y2"],
                label="Region Bounding Boxes (x_min, y_min, x_max, y_max)",
            )
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(0, 100, step=1, label="Seed", value=42)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                guidance_scale = gr.Slider(1, 30, step=0.5, label="Guidance Scale", value=7.5)
                num_inference_steps = gr.Slider(1, 50, step=1, label="Number of inference steps", value=50)
        with gr.Column():
            bbox_visualization_img = gr.Image(type="pil", label="Bounding Box Visualization")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    gr.Button("Generate").click(
        fn=process_image_and_text,
        inputs=[global_caption, box_detail_phrases_list, boxes, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[bbox_visualization_img, output_image],
    )

    gr.Examples(
        examples=get_samples(),
        inputs=[global_caption, box_detail_phrases_list, boxes],
        outputs=[bbox_visualization_img, output_image],
        fn=process_image_and_text,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()