Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from PIL import Image, ImageFilter
+from PIL import Image, ImageDraw, ImageFilter
 import requests
 from io import BytesIO
 import torch
@@ -11,121 +11,152 @@ import cv2

Removed (old lines 11-131). Only removed lines whose content survives in the diff view are listed; several are truncated:

-# ---------- Step 1: Generate Plain T-shirt ----------
-    prompt = f"A plain {color_prompt
-        return Image.open(BytesIO(image_data)).convert("RGB")
-    except Exception as e:
-        print(f"Error generating cloth: {e}")
-        raise Exception("Failed to generate the T-shirt. Check your API key and internet connection.")
-# ---------- Step 2: Generate Design ----------
-    prompt = f"A bold {design_prompt
-        return Image.open(BytesIO(image_data)).convert("RGBA")
-    except Exception as e:
-        print(f"Error generating design: {e}")
-        raise Exception("Failed to generate the design. Check your API key and internet connection.")
-#
-    segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
-except Exception as e:
-    print(f"Error loading segmentation model: {e}")
-    raise Exception("Failed to load the segmentation model. Check your PyTorch installation.")
-        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    input_tensor =
-        output = segmentation_model(input_tensor)[
-    return
-#
-def
-    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-# ---------- Step 6: Final Design Pipeline ----------
-        return
-        return None, None, None
-#
-    gr.Markdown("# **AI
-    gr.Markdown("Generate
-            design_output = gr.Image(label="Generated Design")
-            final_output = gr.Image(label="Final Fitted T-Shirt Design")
-        outputs=
-interface.launch(debug=True

New (lines 11-162; added lines marked with "+", unchanged context lines unmarked):

 # AI model repo for design generation
 repo = "artificialguybr/TshirtDesignRedmond-V2"

 def generate_cloth(color_prompt):
+    prompt = f"A plain {color_prompt} colored T-shirt hanging on a plain wall."
     api_url = f"https://api-inference.huggingface.co/models/{repo}"
+    headers = {}
     payload = {"inputs": prompt, "parameters": {"num_inference_steps": 30}}
+    response = requests.post(api_url, headers=headers, json=payload)
+    if response.status_code == 200:
+        return Image.open(BytesIO(response.content)).convert("RGB")
+    else:
+        raise Exception(f"Error generating cloth: {response.status_code}")

 def generate_design(design_prompt):
+    prompt = f"A bold {design_prompt} design with vibrant colors, highly detailed."
     api_url = f"https://api-inference.huggingface.co/models/{repo}"
+    headers = {}
     payload = {"inputs": prompt, "parameters": {"num_inference_steps": 30}}
+    response = requests.post(api_url, headers=headers, json=payload)
+    if response.status_code == 200:
+        return Image.open(BytesIO(response.content)).convert("RGBA")
+    else:
+        raise Exception(f"Error generating design: {response.status_code}")

+# Load pretrained DeepLabV3 model for T-shirt segmentation
+segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
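On recent torchvision releases the pretrained= argument is deprecated in favor of weights=; an equivalent load, assuming torchvision >= 0.13:

    from torchvision import models

    segmentation_model = models.segmentation.deeplabv3_resnet101(
        weights=models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT
    ).eval()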

+# Apply segmentation to extract the T-shirt mask
 def get_tshirt_mask(image):
+    image = image.convert("RGB")  # Ensure 3 channels
+    preprocess = T.Compose([
+        T.Resize((520, 520)),  # Resize to a fixed 520x520 input for the segmentation model
         T.ToTensor(),
+        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
     ])
+    input_tensor = preprocess(image).unsqueeze(0)
     with torch.no_grad():
+        output = segmentation_model(input_tensor)["out"][0]

+    # Class 15 is the "person" class of the Pascal VOC label set this model
+    # predicts; there is no T-shirt class, so the person region is used as a
+    # proxy for the shirt area.
     mask = output.argmax(0).byte().cpu().numpy()
+    raw_mask = Image.fromarray((mask == 15).astype("uint8") * 255)  # Binary mask
+    processed_mask = post_process_mask(raw_mask)  # Apply post-processing
+    return processed_mask.resize(image.size)

+# Post-process mask to improve quality
+def post_process_mask(mask):
+    # Convert mask to NumPy array
     mask_np = np.array(mask)

+    # Morphological operations to refine the mask
+    kernel = np.ones((5, 5), np.uint8)
+    mask_np = cv2.dilate(mask_np, kernel, iterations=2)  # Expand mask
+    mask_np = cv2.erode(mask_np, kernel, iterations=1)  # Remove noise
+
+    # Convert back to PIL image and smooth
+    processed_mask = Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))
+    return processed_mask
+
+# Get bounding box (x_min, y_min, x_max, y_max) from mask
+def get_bounding_box(mask):
+    mask_np = np.array(mask)
+    ys, xs = np.where(mask_np > 0)  # np.where returns (row, col), i.e. (y, x)
+    if xs.size == 0:
+        raise Exception("No T-shirt detected in the image.")
+    return (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
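A quick offline check of the bounding-box convention, using a synthetic mask instead of the segmentation model (illustrative only, not part of the app):

    import numpy as np
    from PIL import Image

    fake = np.zeros((300, 200), dtype=np.uint8)   # height 300, width 200
    fake[60:240, 40:160] = 255                    # white block standing in for the shirt
    bbox = get_bounding_box(Image.fromarray(fake))
    assert bbox == (40, 60, 159, 239)             # (x_min, y_min, x_max, y_max)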
+
+# Visualize mask and bounding box on the image for debugging
+def visualize_mask(image, mask):
+    overlay = image.copy().convert("RGBA")
+    draw = ImageDraw.Draw(overlay)
+    bbox = get_bounding_box(mask)
+    draw.rectangle(bbox, outline="red", width=3)  # Draw bounding box
+    blended = Image.blend(image.convert("RGBA"), overlay, alpha=0.5)  # Overlay mask
+    blended.save("debug_visualization.png")  # Save debug image
+    return blended
+
+# Overlay design on the T-shirt
 def overlay_design(cloth_image, design_image):
+    # Ensure images are in RGBA mode
+    cloth_image = cloth_image.convert("RGBA")
+    design_image = design_image.convert("RGBA")
+
+    # Generate T-shirt mask
     mask = get_tshirt_mask(cloth_image)
+
+    # Extract bounding box for precise placement
+    bbox = get_bounding_box(mask)
+    tshirt_width = bbox[2] - bbox[0]
+    tshirt_height = bbox[3] - bbox[1]
+
+    # Resize the design to fit the T-shirt
+    design_width = int(tshirt_width * 0.6)
+    design_height = int(tshirt_height * 0.6)
+    resized_design = design_image.resize((design_width, design_height))
+
+    # Position the design in the center of the T-shirt
+    design_position = (
+        bbox[0] + (tshirt_width - design_width) // 2,
+        bbox[1] + (tshirt_height - design_height) // 2
+    )
+
+    # Create a transparent layer for the design
+    transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
+    transparent_layer.paste(resized_design, design_position, resized_design)
+
+    # Mask the design to the T-shirt area
+    masked_design = Image.composite(transparent_layer, Image.new("RGBA", cloth_image.size), mask)
+
+    # Combine the cloth image with the masked design
+    final_image = Image.alpha_composite(cloth_image, masked_design)
+    return final_image
+
+def debug_intermediate_outputs(cloth_image, mask):
+    # Save debug images
+    cloth_image.save("debug_cloth_image.png")
+    mask.save("debug_tshirt_mask.png")

 def design_tshirt(color_prompt, design_prompt):
+    cloth_image = generate_cloth(color_prompt)
+    design_image = generate_design(design_prompt)
     try:
+        mask = get_tshirt_mask(cloth_image)
+        debug_intermediate_outputs(cloth_image, mask)  # Debugging
+        visualize_mask(cloth_image, mask)  # Save visualization
         final_image = overlay_design(cloth_image, design_image)
+        return final_image
     except Exception as e:
+        raise Exception(f"Error in design process: {str(e)}")

+# Gradio UI
 with gr.Blocks() as interface:
+    gr.Markdown("# **AI Cloth Designer**")
+    gr.Markdown("Generate custom T-shirts by specifying a color and adding a design that perfectly fits the T-shirt.")
     with gr.Row():
         with gr.Column():
             color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
             design_prompt = gr.Textbox(label="Design Details", placeholder="E.g., Abstract art, Nature patterns")
             generate_button = gr.Button("Generate T-Shirt")
         with gr.Column():
+            output_image = gr.Image(label="Final T-Shirt Design")

     generate_button.click(
         design_tshirt,
         inputs=[color_prompt, design_prompt],
+        outputs=output_image,
     )

+interface.launch(debug=True)
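For completeness, a plausible requirements.txt for this Space, inferred from the imports in app.py (package names only; this commit does not touch the requirements file):

    gradio
    requests
    pillow
    torch
    torchvision
    numpy
    opencv-python-headless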