# Header copied from the hosted file view (kept as a comment so the module parses):
# last commit 97a01d0 "Update app.py" by Gopalag; file size 10.1 kB.
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import io
# FLUX.1-schnell text-to-image pipeline, loaded once at import time in
# bfloat16 and moved to the GPU when CUDA is available (CPU otherwise).
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Largest seed infer() may draw: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders in the UI.
MAX_IMAGE_SIZE = 2048
import numpy as np
from collections import Counter
def get_prominent_colors(image, num_colors=5, edge_percentile=90):
    """
    Get the most prominent colors from an image, focusing on edges.

    A pixel counts as an edge pixel when the gradient magnitude of its
    channel-averaged intensity is strictly greater than the
    ``edge_percentile``-th percentile of all gradient magnitudes. The colors
    of those edge pixels are tallied, and the most common ones are returned.

    Args:
        image: A PIL image or array-like, either (H, W, C) color or
            (H, W) grayscale.
        num_colors: Maximum number of colors to return.
        edge_percentile: Percentile in [0, 100] used as the edge threshold.
            90 keeps at most roughly the strongest 10% of gradients.

    Returns:
        A list of ``(color_tuple, count)`` pairs, most frequent first, as
        produced by ``Counter.most_common``. Grayscale input yields 1-tuples.
        The list is empty when no pixel exceeds the threshold (for example a
        single flat color) or when either image dimension is smaller than 2.
    """
    # Promote (H, W) grayscale to (H, W, 1) so mean(axis=2) below always has
    # a channel axis to average over.
    img_array = np.atleast_3d(np.asarray(image))
    height, width = img_array.shape[:2]
    # np.gradient needs at least two samples along each differentiated axis.
    if height < 2 or width < 2:
        return []

    # One gradient call gives both derivatives: axis 0 (rows) first, then
    # axis 1 (columns).
    gradient_y, gradient_x = np.gradient(img_array.mean(axis=2))
    gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)

    edge_threshold = np.percentile(gradient_magnitude, edge_percentile)
    edge_mask = gradient_magnitude > edge_threshold

    # tolist() yields plain Python numbers, which are cheap to hash. Boolean
    # indexing keeps row-major pixel order, so Counter's tie-breaking (first
    # seen wins) matches the per-pixel loop this replaces.
    edge_colors = img_array[edge_mask].tolist()
    color_counts = Counter(map(tuple, edge_colors))
    return color_counts.most_common(num_colors)
def create_tshirt_preview(design_image, tshirt_color="white", template_path="image.jpeg",
                          width_ratio=0.35, top_ratio=0.2):
    """
    Overlay the design onto the t-shirt template image.

    The design is resized to ``width_ratio`` of the template width, keeping
    its aspect ratio. It is centered horizontally and placed ``top_ratio`` of
    the template height from the top. If the design is RGBA, its alpha
    channel is used as the paste mask.

    Args:
        design_image: A PIL image, or an array accepted by ``Image.fromarray``.
        tshirt_color: Accepted for interface compatibility with the UI, but
            currently unused; the template is returned in its own color.
            NOTE(review): recoloring the shirt is not implemented.
        template_path: Path of the t-shirt template image.
        width_ratio: Design width as a fraction of the template width.
        top_ratio: Vertical offset of the design's top edge, as a fraction
            of the template height.

    Returns:
        A new PIL image: the template with the design pasted onto it.
    """
    # Load the template fully and close the file right away; copy() forces
    # the pixel data into memory so it outlives the `with` block.
    with Image.open(template_path) as template:
        tshirt = template.copy()
    tshirt_width, tshirt_height = tshirt.size

    if not isinstance(design_image, Image.Image):
        design_image = Image.fromarray(design_image)

    # Scale to the requested fraction of the shirt width, preserving aspect
    # ratio. Clamp to at least 1 px so resize() never sees a zero size.
    design_width = max(1, int(tshirt_width * width_ratio))
    design_height = max(1, int(design_width * design_image.size[1] / design_image.size[0]))
    design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)

    # Horizontally centered; vertical placement depends on the template photo.
    x = (tshirt_width - design_width) // 2
    y = int(tshirt_height * top_ratio)

    # Use the alpha channel as the mask so transparent areas show the shirt.
    mask = design_image.getchannel("A") if design_image.mode == "RGBA" else None
    tshirt.paste(design_image, (x, y), mask)
    return tshirt
def enhance_prompt_for_tshirt(prompt, style=None):
    """
    Append t-shirt-oriented terms to a user prompt.

    Args:
        prompt: The user's design description.
        style: Optional style key ("minimal", "vintage", ...). Empty, None or
            unrecognized values add no style-specific terms.

    Returns:
        A single comma-separated prompt: the user's text, then the generic
        t-shirt terms, then any terms for the chosen style.
    """
    common_terms = (
        "create t-shirt design",
        "with centered composition",
        "high quality",
        "professional design",
        "clear background",
    )
    terms_by_style = {
        "minimal": ("simple geometric shapes", "clean lines", "minimalist illustration"),
        "vintage": ("distressed effect", "retro typography", "vintage illustration"),
        "artistic": ("hand-drawn style", "watercolor effect", "artistic illustration"),
        "geometric": ("abstract shapes", "geometric patterns", "modern design"),
        "typography": ("bold typography", "creative lettering", "text-based design"),
        "realistic": ("realistic", "cinematic", "photograph"),
    }

    pieces = [f"{prompt}", *common_terms]
    # Check truthiness first so empty or None styles never reach the lookup.
    if style and style in terms_by_style:
        pieces.extend(terms_by_style[style])
    return ", ".join(pieces)
@spaces.GPU()
def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
          width=1024, height=1024, num_inference_steps=4,
          progress=gr.Progress(track_tqdm=True)):
    """
    Generate a t-shirt design from a prompt and render it on the shirt template.

    Args:
        prompt: Free-text design description.
        style: Optional style key forwarded to ``enhance_prompt_for_tshirt``.
        tshirt_color: Forwarded to ``create_tshirt_preview``.
        seed: Seed for the torch generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Requested output size in pixels.
        num_inference_steps: Number of pipeline inference steps.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            report tqdm progress bars in the UI.

    Returns:
        ``(design_image, tshirt_preview, seed)``. The seed actually used is
        returned so the UI's seed slider can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Values may arrive as floats (e.g. 1024.0 from an API client).
    # torch's manual_seed and the pipeline's size arithmetic need plain ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    enhanced_prompt = enhance_prompt_for_tshirt(prompt, style)
    # A seeded generator makes the design reproducible for a given seed.
    generator = torch.Generator().manual_seed(seed)

    design_image = pipe(
        prompt=enhanced_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        # guidance_scale=0.0 follows the diffusers FLUX.1-schnell example.
        guidance_scale=0.0,
    ).images[0]

    tshirt_preview = create_tshirt_preview(design_image, tshirt_color)
    return design_image, tshirt_preview, seed
# Available t-shirt colors for the UI dropdown (display name -> hex code).
# NOTE(review): only the keys are read in this file (dropdown choices); the
# hex values are not used by the preview yet.
TSHIRT_COLORS = {
    "White": "#FFFFFF",
    "Black": "#000000",
    "Navy": "#000080",
    "Gray": "#808080",
}

# Example rows as [design description, style, t-shirt color]. Column order must
# match the `inputs` list given to gr.Examples.
examples = [
    ["Cool geometric mountain landscape", "minimal", "White"],
    ["Vintage motorcycle with flames", "vintage", "Black"],
    ["flamingo in scenic forest", "realistic", "White"],
    ["Adventure Starts typography", "typography", "White"],
]

# Style choices offered in the dropdown. Keep in sync with the style keys
# understood by enhance_prompt_for_tshirt.
styles = [
    "minimal",
    "vintage",
    "artistic",
    "geometric",
    "typography",
    "realistic",
]
# Custom stylesheet passed to gr.Blocks(css=css, ...). Its selectors target the
# elem_id / elem_classes assigned in the layout:
#   - #col-container: the outer column
#   - .main-title and .subtitle: the header Markdown
#   - .design-input: the prompt box
#   - .results-row: a two-column grid holding the design and the t-shirt preview
# NOTE(review): the "generate-button", "result-image" and "preview-image"
# classes are also assigned in the layout but have no rules here.
css = """
#col-container {
margin: 0 auto;
max-width: 1200px !important;
padding: 20px;
}
.main-title {
text-align: center;
color: #2d3748;
margin-bottom: 1rem;
font-family: 'Poppins', sans-serif;
}
.subtitle {
text-align: center;
color: #4a5568;
margin-bottom: 2rem;
font-family: 'Inter', sans-serif;
font-size: 0.95rem;
line-height: 1.5;
}
.design-input {
border: 2px solid #e2e8f0;
border-radius: 10px;
padding: 12px !important;
margin-bottom: 1rem !important;
font-size: 1rem;
transition: all 0.3s ease;
}
.results-row {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
margin-top: 20px;
}
"""
# Build the Gradio UI. Components are laid out in the order they are created
# inside each context manager, so the statements below are order-sensitive.
# NOTE(review): this nesting was reconstructed from a copy that had lost its
# indentation; confirm it against the deployed layout.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        # Header: title and subtitle; elem_classes hook into the custom CSS.
        gr.Markdown(
            """
# 👕Deradh's T-Shirt Design Generator
""",
            elem_classes=["main-title"]
        )
        gr.Markdown(
            """
Create unique t-shirt designs using Deradh's AI.
Describe your design idea and select a style to generate professional-quality artwork
perfect for custom t-shirts.
""",
            elem_classes=["subtitle"]
        )
        # Input row: description, style, shirt color, and the generate button.
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Text(
                    label="Design Description",
                    show_label=False,
                    max_lines=1,
                    placeholder="Describe your t-shirt design idea",
                    container=False,
                    elem_classes=["design-input"]
                )
            with gr.Column(scale=1):
                # "" means "no style"; enhance_prompt_for_tshirt skips falsy styles.
                style = gr.Dropdown(
                    choices=[""] + styles,
                    value="",
                    label="Style",
                    container=False
                )
            with gr.Column(scale=1):
                tshirt_color = gr.Dropdown(
                    choices=list(TSHIRT_COLORS.keys()),
                    value="White",
                    label="T-Shirt Color",
                    container=False
                )
            run_button = gr.Button(
                "✨ Generate",
                scale=0,
                elem_classes=["generate-button"]
            )
        # Outputs: the raw design and the t-shirt preview, side by side via .results-row.
        with gr.Row(elem_classes=["results-row"]):
            result = gr.Image(
                label="Generated Design",
                show_label=True,
                elem_classes=["result-image"]
            )
            preview = gr.Image(
                label="T-Shirt Preview",
                show_label=True,
                elem_classes=["preview-image"]
            )
        with gr.Accordion("🔧 Advanced Settings", open=False):
            with gr.Group():
                seed = gr.Slider(
                    label="Design Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # When checked, infer() ignores the seed slider and draws a random seed.
                randomize_seed = gr.Checkbox(
                    label="Randomize Design",
                    value=True
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                num_inference_steps = gr.Slider(
                    label="Generation Quality (Steps)",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,
                )
        # Example rows fill only prompt/style/color, so every other infer()
        # parameter takes its default (seed=42, no randomization).
        # cache_examples=True has Gradio precompute and store these outputs.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, style, tshirt_color],
            outputs=[result, preview, seed],
            cache_examples=True
        )
    # Run generation on button click or prompt submit. Writing the seed back
    # lets the slider show the seed actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
        outputs=[result, preview, seed]
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not launch a server.
if __name__ == "__main__":
    demo.launch()