Gopalag's picture
Update app.py
9693fed verified
raw
history blame
8.45 kB
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import io
# --- Model / runtime configuration -------------------------------------------
dtype = torch.bfloat16

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load FLUX.1-schnell once, at import time, so every request reuses the same
# pipeline object instead of reloading the weights.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
).to(device)

# Upper bound for seeds (largest signed 32-bit integer); used by the seed
# slider and by infer() when drawing a random seed.
MAX_SEED = np.iinfo(np.int32).max
# Largest width/height (in pixels) the size sliders allow.
MAX_IMAGE_SIZE = 2048
def create_tshirt_preview(design_image, tshirt_color="white", *,
                          tshirt_size=(800, 1000), design_scale=0.3,
                          top_offset=0.25):
    """
    Paste a design onto a plain, solid-colour canvas standing in for a t-shirt.

    No shirt silhouette is drawn: the "shirt" is simply a ``tshirt_size``
    rectangle filled with ``tshirt_color``.

    Args:
        design_image: The design, either a ``PIL.Image.Image`` or an array
            accepted by ``Image.fromarray``.
        tshirt_color: Any colour spec accepted by ``Image.new`` (a colour name
            such as ``"white"`` or a hex string such as ``"#000080"``).
        tshirt_size: ``(width, height)`` of the canvas in pixels.
        design_scale: Target design width as a fraction of the canvas width.
        top_offset: Vertical position of the design's top edge, as a fraction
            of the canvas height (0.25 = one quarter of the way down).

    Returns:
        A new RGB ``PIL.Image.Image`` with the design horizontally centred.
        If the design has zero width or height, the blank canvas is returned.
    """
    tshirt_width, tshirt_height = tshirt_size
    tshirt = Image.new('RGB', (tshirt_width, tshirt_height), tshirt_color)

    if not isinstance(design_image, Image.Image):
        design_image = Image.fromarray(design_image)

    src_width, src_height = design_image.size
    if src_width == 0 or src_height == 0:
        # Nothing to paste; avoid dividing by zero in the aspect-ratio maths.
        return tshirt

    # Promote greyscale/palette images that carry transparency to RGBA so their
    # alpha channel can serve as the paste mask instead of being discarded.
    if design_image.mode in ("LA", "PA") or (
        design_image.mode == "P" and "transparency" in design_image.info
    ):
        design_image = design_image.convert("RGBA")

    y = int(tshirt_height * top_offset)

    # Scale to the target width while keeping the aspect ratio. A very tall
    # design would otherwise run past the bottom edge and be silently cropped
    # by paste(), so shrink it further to fit the space below `y`.
    design_width = max(1, int(tshirt_width * design_scale))
    design_height = max(1, int(design_width * src_height / src_width))
    available_height = max(1, tshirt_height - y)
    if design_height > available_height:
        design_height = available_height
        design_width = max(1, int(design_height * src_width / src_height))
    design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)

    # Centre horizontally on the canvas.
    x = (tshirt_width - design_width) // 2

    # Use the alpha channel as the mask so transparent pixels show the shirt.
    mask = design_image.getchannel("A") if design_image.mode == 'RGBA' else None
    tshirt.paste(design_image, (x, y), mask)
    return tshirt
def enhance_prompt_for_tshirt(prompt, style=None):
    """
    Decorate a user prompt with phrases that steer the model toward t-shirt art.

    The prompt is always followed by a fixed list of generic composition and
    quality phrases. When ``style`` names one of the known presets, that
    preset's three descriptors are appended as well. Unknown or empty styles
    add nothing.

    Returns the resulting comma-separated prompt string.
    """
    generic_phrases = (
        "create a t-shirt design",
        "with centered composition",
        "4k high quality",
        "professional design",
        "clear background",
    )
    presets = dict(
        minimal=("simple geometric shapes", "clean lines", "minimalist illustration"),
        vintage=("distressed effect", "retro typography", "vintage illustration"),
        artistic=("hand-drawn style", "watercolor effect", "artistic illustration"),
        geometric=("abstract shapes", "geometric patterns", "modern design"),
        typography=("bold typography", "creative lettering", "text-based design"),
    )

    # f-string keeps the original's str() formatting of `prompt`.
    pieces = [f"{prompt}", *generic_phrases]
    if style and style in presets:
        pieces.extend(presets[style])
    return ", ".join(pieces)
@spaces.GPU()
def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
          width=1024, height=1024, num_inference_steps=4,
          progress=gr.Progress(track_tqdm=True)):
    """
    Generate a t-shirt design with FLUX.1-schnell and build a shirt preview.

    Args:
        prompt: Free-text description of the design.
        style: Optional style preset name passed to
            ``enhance_prompt_for_tshirt``. Empty or unknown styles add nothing.
        tshirt_color: A key of ``TSHIRT_COLORS`` (e.g. ``"Navy"``) or any
            colour spec ``PIL.Image.new`` accepts. ``None`` falls back to white.
        seed: RNG seed. Ignored when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker. ``track_tqdm=True`` forwards the
            pipeline's tqdm bar to the UI.

    Returns:
        ``(design_image, tshirt_preview, seed)``. ``seed`` is the value actually
        used, so the UI can display or reuse it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; the torch generator and the pipeline's
    # size/step arguments expect ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
    # Seeded generator (created on the default/CPU device) so a given seed
    # reproduces the same design.
    generator = torch.Generator().manual_seed(seed)

    # Generate the design.
    design_image = pipe(
        prompt=enhanced_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=0.0
    ).images[0]

    # Resolve UI colour names through TSHIRT_COLORS so its hex values are the
    # single source of truth. Unknown values pass straight through to PIL, and
    # a cleared dropdown (None) falls back to white instead of an
    # uninitialised canvas.
    shirt_color = TSHIRT_COLORS.get(tshirt_color, tshirt_color or "white")
    tshirt_preview = create_tshirt_preview(design_image, shirt_color)
    return design_image, tshirt_preview, seed
# Shirt colours offered in the UI, mapping display name -> hex RGB string.
TSHIRT_COLORS = dict(
    White="#FFFFFF",
    Black="#000000",
    Navy="#000080",
    Gray="#808080",
)

# Example rows for gr.Examples, in input order: [prompt, style, shirt colour].
examples = [
    list(row)
    for row in (
        ("Cool geometric mountain landscape", "minimal", "White"),
        ("Vintage motorcycle with flames", "vintage", "Black"),
        ("Abstract watercolor butterfly in forest", "artistic", "White"),
        ("Adventure Awaits typography", "typography", "Gray"),
    )
]

# Style presets shown in the Style dropdown (must match the preset names that
# enhance_prompt_for_tshirt understands).
styles = "minimal vintage artistic geometric typography".split()
# Custom stylesheet passed to gr.Blocks(css=...).
# Selectors here target the elem_id / elem_classes assigned in the UI:
# "#col-container", ".main-title", ".subtitle", ".design-input", ".results-row".
# The UI also assigns "generate-button", "result-image" and "preview-image"
# classes, which currently have no rules in this stylesheet.
# NOTE(review): 'Poppins' and 'Inter' are not loaded anywhere in this file, so
# they only apply if already available to the browser; otherwise sans-serif.
css = """
#col-container {
margin: 0 auto;
max-width: 1200px !important;
padding: 20px;
}
.main-title {
text-align: center;
color: #2d3748;
margin-bottom: 1rem;
font-family: 'Poppins', sans-serif;
}
.subtitle {
text-align: center;
color: #4a5568;
margin-bottom: 2rem;
font-family: 'Inter', sans-serif;
font-size: 0.95rem;
line-height: 1.5;
}
.design-input {
border: 2px solid #e2e8f0;
border-radius: 10px;
padding: 12px !important;
margin-bottom: 1rem !important;
font-size: 1rem;
transition: all 0.3s ease;
}
.results-row {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-top: 20px;
}
"""
# Gradio UI: a centred column containing the header, the input row, the two
# result images, an advanced-settings accordion and cached examples.
# NOTE(review): the copy this was taken from had lost its indentation; the
# nesting below (notably what sits inside gr.Group / gr.Row in the accordion)
# is reconstructed and should be confirmed against the original layout.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
# 👕 Deradh's T-Shirt Design Generator
""",
            elem_classes=["main-title"]
        )
        gr.Markdown(
            """
Create unique t-shirt designs using AI.
Describe your design idea and select a style to generate professional-quality artwork
perfect for custom t-shirts.
""",
            elem_classes=["subtitle"]
        )
        # Input row: prompt, style preset, shirt colour and the Generate button.
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Text(
                    label="Design Description",
                    show_label=False,
                    max_lines=1,
                    placeholder="Describe your t-shirt design idea",
                    container=False,
                    elem_classes=["design-input"]
                )
            with gr.Column(scale=1):
                # "" means "no style"; enhance_prompt_for_tshirt ignores falsy styles.
                style = gr.Dropdown(
                    choices=[""] + styles,
                    value="",
                    label="Style",
                    container=False
                )
            with gr.Column(scale=1):
                # Values are the display names (keys of TSHIRT_COLORS), e.g. "Navy".
                tshirt_color = gr.Dropdown(
                    choices=list(TSHIRT_COLORS.keys()),
                    value="White",
                    label="T-Shirt Color",
                    container=False
                )
            # scale=0 keeps the button at its natural width within the row.
            run_button = gr.Button(
                "✨ Generate",
                scale=0,
                elem_classes=["generate-button"]
            )
        # Outputs: the raw generated design and its shirt preview, side by side.
        with gr.Row(elem_classes=["results-row"]):
            result = gr.Image(
                label="Generated Design",
                show_label=True,
                elem_classes=["result-image"]
            )
            preview = gr.Image(
                label="T-Shirt Preview",
                show_label=True,
                elem_classes=["preview-image"]
            )
        with gr.Accordion("🔧 Advanced Settings", open=False):
            with gr.Group():
                # Also used as an output: infer() returns the seed it actually
                # used, which is written back into this slider.
                seed = gr.Slider(
                    label="Design Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # When checked (the default), infer() ignores the slider and
                # draws a random seed.
                randomize_seed = gr.Checkbox(
                    label="Randomize Design",
                    value=True
                )
            with gr.Row():
                # step=32 keeps the requested size on multiples of 32 px.
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            num_inference_steps = gr.Slider(
                label="Generation Quality (Steps)",
                minimum=1,
                maximum=50,
                step=1,
                value=4,
            )
        # Examples supply only prompt/style/colour, so the cached runs use
        # infer()'s defaults for everything else (seed=42, randomize_seed=False,
        # 1024x1024, 4 steps). cache_examples=True makes Gradio call infer to
        # precompute these outputs.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, style, tshirt_color],
            outputs=[result, preview, seed],
            cache_examples=True
        )
    # Clicking Generate or pressing Enter in the prompt box both run infer.
    # Input order matches infer's positional parameters; `progress` is
    # injected by Gradio.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, preview, seed]
    )

# Start the Gradio server (runs at import time of this module).
demo.launch()