# CreatiLayout / app.py
# Last commit by HuiZhang0812: 25ebb7a ("Update app.py", verified); file size 8.56 kB.
import gradio as gr
import torch
import spaces
from src.models.transformer_sd3_SiamLayout import SiamLayoutSD3Transformer2DModel
from src.pipeline.pipeline_CreatiLayout import CreatiLayoutSD3Pipeline
from utils.bbox_visualization import bbox_visualization,scale_boxes
from PIL import Image
import os
import pandas as pd
from huggingface_hub import login
# --- Hub authentication -------------------------------------------------------
# The app refuses to start without an HF_TOKEN environment variable/secret and
# logs in with it before any weights are downloaded.
hf_token = os.environ.get("HF_TOKEN")
if hf_token is None:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret.")
login(token=hf_token)

# --- Model loading --------------------------------------------------------------
# Base SD3-medium diffusers repository, plus the CreatiLayout checkpoint whose
# "transformer" subfolder provides the layout-aware (SiamLayout) transformer that
# is handed to the pipeline in place of its default one.
model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
ckpt_path = "HuiZhang0812/CreatiLayout"

transformer_additional_kwargs = {"attention_type": "layout", "strict": True}
transformer = SiamLayoutSD3Transformer2DModel.from_pretrained(
    ckpt_path,
    subfolder="transformer",
    torch_dtype=torch.float16,
    **transformer_additional_kwargs,
)

# Half-precision weights, moved to the GPU once at import time so every request
# reuses the same pipeline object.
pipe = CreatiLayoutSD3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.float16,
).to("cuda")
print("pipeline is loaded.")
def _is_blank_cell(value):
    """Return True when a Dataframe cell holds no usable value (None, NaN or whitespace-only text)."""
    if value is None:
        return True
    if isinstance(value, str):
        return not value.strip()
    return bool(pd.isna(value))


def _region_captions(frame):
    """Return the first-column caption of every non-blank row of *frame*, in row order."""
    return [row[0] for row in frame.values.tolist() if row and not _is_blank_cell(row[0])]


def _region_boxes(frame):
    """Convert every non-blank row of *frame* into a list of floats, in row order.

    Rows whose cells are all blank are skipped, because the editable Dataframe
    can hand back empty rows that ``float`` cannot parse. Any remaining cell that
    is not numeric raises ``gr.Error`` so the user sees a readable message.
    """
    parsed = []
    for row in frame.values.tolist():
        if all(_is_blank_cell(cell) for cell in row):
            continue
        try:
            parsed.append([float(cell) for cell in row])
        except (TypeError, ValueError) as err:
            raise gr.Error(
                f"Invalid bounding box {row}: every coordinate must be a number."
            ) from err
    return parsed


@spaces.GPU
def process_image_and_text(global_caption, box_detail_phrases_list: pd.DataFrame, boxes: pd.DataFrame, seed: int = 42, randomize_seed: bool = False, guidance_scale: float = 7.5, num_inference_steps: int = 50):
    """Draw the requested layout and generate a 1024x1024 image with CreatiLayout.

    Args:
        global_caption: Prompt describing the whole image.
        box_detail_phrases_list: Single-column frame of region captions, one row per region.
        boxes: Frame with columns x1, y1, x2, y2, one row per region, in the same
            order as the captions (the UI labels them x_min, y_min, x_max, y_max).
            Blank rows in either frame are ignored.
        seed: Seed for the CUDA generator; ignored when ``randomize_seed`` is true.
        randomize_seed: Draw a new seed from [0, 100) instead of using ``seed``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.

    Returns:
        ``(bbox_visualization_img, result_img)``: the boxes and labels drawn by
        ``bbox_visualization`` on a white canvas, and the first image produced by
        the pipeline.

    Raises:
        gr.Error: if a box cell is not numeric, or if the number of non-blank
            captions differs from the number of non-blank boxes.
    """
    if randomize_seed:
        seed = torch.randint(0, 100, (1,)).item()
    # UI sliders may deliver floats; manual_seed and the step count need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    height = 1024
    width = 1024

    phrases = _region_captions(box_detail_phrases_list)
    bbox_raw = _region_boxes(boxes)
    # Captions and boxes are handed over side by side (as "labels"/"boxes" below and
    # as bbox_phrases/bbox_raw to the pipeline), presumably paired by index, so the
    # two lists must have the same length.
    if len(phrases) != len(bbox_raw):
        raise gr.Error(
            f"Got {len(phrases)} region captions but {len(bbox_raw)} bounding boxes; "
            "every region needs exactly one caption and one box."
        )

    # "white" is (255, 255, 255); the former "rgb(256,256,256)" used an
    # out-of-range 8-bit channel value.
    white_image = Image.new("RGB", (width, height), color="white")
    show_input = {"boxes": scale_boxes(bbox_raw, width, height), "labels": phrases}
    bbox_visualization_img = bbox_visualization(white_image, show_input)

    result_img = pipe(
        prompt=global_caption,
        generator=torch.Generator(device="cuda").manual_seed(seed),
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        bbox_phrases=phrases,
        bbox_raw=bbox_raw,
        height=height,
        width=width,
    ).images[0]
    return bbox_visualization_img, result_img
def get_samples():
    """Build the example rows shown under the demo.

    Each returned row is ``[global_caption, region_captions, region_bboxes]``,
    matching the three example inputs of the UI:

    * ``global_caption`` -- prompt for the whole image;
    * ``region_captions`` -- one single-element list per region, i.e. one row of
      the one-column captions Dataframe;
    * ``region_bboxes`` -- ``[x1, y1, x2, y2]`` per region, in the same order as
      the captions.

    Fresh lists are built on every call.
    """
    # (global caption, region captions, region boxes) for each example.
    raw_examples = (
        (
            "A picturesque scene features Iron Man standing confidently on a rugged rock by the sea, "
            "holding a drawing board with his hands. "
            "The board displays the words 'Creative Layout' in a playful, hand-drawn font. "
            "The serene sea shimmers under the setting sun. "
            "The sky is painted with a gradient of warm colors, from deep oranges to soft purples.",
            [
                "Iron Man standing confidently on a rugged rock.",
                "A rugged rock by the sea.",
                "A drawing board with the words \"Creative Layout\" in a playful, hand-drawn font.",
                "The serene sea shimmers under the setting sun.",
                "The sky is a shade of deep orange to soft purple.",
            ],
            [
                [0.40, 0.35, 0.55, 0.80],
                [0.35, 0.75, 0.60, 0.95],
                [0.40, 0.45, 0.55, 0.65],
                [0.00, 0.30, 1.00, 0.90],
                [0.00, 0.00, 1.00, 0.30],
            ],
        ),
        (
            "This is a photo showcasing two wooden benches in a park. "
            "The bench on the left is painted in a vibrant blue, while the one on the right is painted in a green. "
            "Both are placed on a path paved with stones, surrounded by lush trees and shrubs. "
            "The sunlight filters through the leaves, casting dappled shadows on the ground, "
            "creating a tranquil and comfortable atmosphere.",
            [
                "A weathered, blue wooden bench with green elements in a natural setting.",
                "Old, weathered wooden benches with green paint.",
                "A dirt path in a park with green grass on the sides and two colorful wooden benches.",
                "Thick, verdant foliage of mature trees in a dense forest.",
            ],
            [
                [0.30, 0.44, 0.62, 0.78],
                [0.54, 0.41, 0.75, 0.65],
                [0.00, 0.39, 1.00, 1.00],
                [0.00, 0.00, 1.00, 0.43],
            ],
        ),
        (
            "This is a wedding photo taken in a photography studio, showing a newlywed couple "
            "sitting on a brown leather sofa in a modern indoor setting. "
            "The groom is dressed in a pink suit, paired with a pink tie and white shirt, "
            "while the bride is wearing a white wedding dress with a long veil. "
            "They are sitting on a brown leather sofa, with a wooden table in front of them, "
            "on which a bouquet of flowers is placed. "
            "The background is a bar with a staircase and a wall decorated with lights, "
            "creating a warm and romantic atmosphere.",
            [
                "A floral arrangement consisting of roses, carnations, and eucalyptus leaves on a wooden surface.",
                "A white wedding dress with off-the-shoulder ruffles and a long, sheer veil.",
                "A polished wooden table with visible grain and knots.",
                "A close-up of a dark brown leather sofa with tufted upholstery and button details.",
                "A man in a pink suit with a white shirt and red tie, sitting on a leather armchair.",
                "A person in a suit seated on a leather armchair near a wooden staircase with books and bottles.",
                "Bride in white gown with veil, groom in maroon suit and pink tie, seated on leather armchairs.",
            ],
            [
                [0.09, 0.65, 0.31, 0.93],
                [0.62, 0.25, 0.89, 0.90],
                [0.01, 0.70, 0.78, 0.99],
                [0.76, 0.65, 1.00, 0.99],
                [0.27, 0.32, 0.72, 0.75],
                [0.00, 0.01, 0.52, 0.72],
                [0.27, 0.09, 0.94, 0.89],
            ],
        ),
    )

    rows = []
    for caption, region_texts, region_boxes in raw_examples:
        # The captions Dataframe has a single column, so each caption is its own row.
        caption_rows = [[text] for text in region_texts]
        rows.append([caption, caption_rows, region_boxes])
    return rows
# Gradio UI. The first column holds the inputs (global caption, region captions,
# region boxes, advanced sampling settings), the second the box preview, and the
# third the generated image. The Generate button and the examples sit below the row.
with gr.Blocks() as demo:
    gr.Markdown("# CreatiLayout: Layout-to-Image generation")
    # The inspiration list was previously flattened into one run-on sentence, so
    # users could not tell where one suggestion ended; render it as a bullet list.
    gr.Markdown(
        "CreatiLayout is a layout-to-image framework for Diffusion Transformer models, "
        "offering high-quality and fine-grained controllable generation based on the global "
        "description and entity annotations. Users need to provide a global description and "
        "the position and description of each entity, as shown in the examples. Please feel "
        "free to modify the position and attributes of the entities in the examples (such as "
        "size, color, shape, text, portrait, etc.). Here are some inspirations:\n"
        "\n"
        "- Iron Man -> Spider Man/Harry Potter/Buzz Lightyear\n"
        "- CreatiLayout -> Hello Friends/Let's Control\n"
        "- drawing board -> round drawing board\n"
        "- Modify the position of the drawing board to (0.4, 0.15, 0.55, 0.35)"
    )

    with gr.Row():
        with gr.Column():
            global_caption = gr.Textbox(lines=2, label="Global Caption")
            # Row i of the captions frame describes the region in row i of the boxes frame.
            box_detail_phrases_list = gr.Dataframe(headers=["Region Captions"], label="Region Captions")
            boxes = gr.Dataframe(headers=["x1", "y1", "x2", "y2"], label="Region Bounding Boxes (x_min,y_min,x_max,y_max)")
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(0, 100, step=1, label="Seed", value=42)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                guidance_scale = gr.Slider(1, 30, step=0.5, label="Guidance Scale", value=7.5)
                num_inference_steps = gr.Slider(1, 50, step=1, label="Number of inference steps", value=50)
        with gr.Column():
            bbox_visualization_img = gr.Image(type="pil", label="Bounding Box Visualization")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    gr.Button("Generate").click(
        fn=process_image_and_text,
        inputs=[global_caption, box_detail_phrases_list, boxes, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[bbox_visualization_img, output_image],
    )

    # Examples fill only the first three inputs; the remaining arguments of
    # process_image_and_text fall back to their defaults. cache_examples=True asks
    # Gradio to pre-compute and store the outputs for these examples.
    gr.Examples(
        examples=get_samples(),
        inputs=[global_caption, box_detail_phrases_list, boxes],
        outputs=[bbox_visualization_img, output_image],
        fn=process_image_and_text,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()