Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ from torchvision import models
|
|
8 |
import numpy as np
|
9 |
import cv2
|
10 |
|
11 |
-
# AI model repo for design generation
|
12 |
repo = "artificialguybr/TshirtDesignRedmond-V2"
|
13 |
|
14 |
def generate_cloth(color_prompt):
|
@@ -33,10 +32,8 @@ def generate_design(design_prompt):
|
|
33 |
else:
|
34 |
raise Exception(f"Error generating design: {response.status_code}")
|
35 |
|
36 |
-
# Load pretrained DeepLabV3 model for T-shirt segmentation
|
37 |
segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
|
38 |
|
39 |
-
# Apply segmentation to extract T-shirt mask
|
40 |
def get_tshirt_mask(image):
|
41 |
image = image.convert("RGB") # Ensure 3 channels
|
42 |
preprocess = T.Compose([
|
@@ -48,27 +45,21 @@ def get_tshirt_mask(image):
|
|
48 |
with torch.no_grad():
|
49 |
output = segmentation_model(input_tensor)["out"][0]
|
50 |
|
51 |
-
# Extract T-shirt mask (class 15 in COCO dataset)
|
52 |
mask = output.argmax(0).byte().cpu().numpy()
|
53 |
raw_mask = Image.fromarray((mask == 15).astype("uint8") * 255) # Binary mask
|
54 |
processed_mask = post_process_mask(raw_mask) # Apply post-processing
|
55 |
return processed_mask.resize(image.size)
|
56 |
|
57 |
-
# Post-process mask to improve quality
|
58 |
def post_process_mask(mask):
|
59 |
-
# Convert mask to NumPy array
|
60 |
mask_np = np.array(mask)
|
61 |
|
62 |
-
# Morphological operations to refine mask
|
63 |
kernel = np.ones((5, 5), np.uint8)
|
64 |
mask_np = cv2.dilate(mask_np, kernel, iterations=2) # Expand mask
|
65 |
mask_np = cv2.erode(mask_np, kernel, iterations=1) # Remove noise
|
66 |
|
67 |
-
# Convert back to PIL image and smooth
|
68 |
processed_mask = Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))
|
69 |
return processed_mask
|
70 |
|
71 |
-
# Get bounding box from mask
|
72 |
def get_bounding_box(mask):
|
73 |
mask_np = np.array(mask)
|
74 |
coords = np.column_stack(np.where(mask_np > 0))
|
@@ -78,7 +69,6 @@ def get_bounding_box(mask):
|
|
78 |
x_max, y_max = coords.max(axis=0)
|
79 |
return (x_min, y_min, x_max, y_max)
|
80 |
|
81 |
-
# Visualize mask and bounding box on the image for debugging
|
82 |
def visualize_mask(image, mask):
|
83 |
overlay = image.copy().convert("RGBA")
|
84 |
draw = ImageDraw.Draw(overlay)
|
@@ -89,38 +79,30 @@ def visualize_mask(image, mask):
|
|
89 |
return blended
|
90 |
|
91 |
def overlay_design(cloth_image, design_image):
|
92 |
-
# Ensure images are in RGBA mode
|
93 |
cloth_image = cloth_image.convert("RGBA")
|
94 |
design_image = design_image.convert("RGBA")
|
95 |
|
96 |
-
# Generate T-shirt mask
|
97 |
mask = get_tshirt_mask(cloth_image)
|
98 |
|
99 |
-
# Extract bounding box and T-shirt coordinates
|
100 |
bbox = get_bounding_box(mask)
|
101 |
tshirt_width = bbox[2] - bbox[0]
|
102 |
tshirt_height = bbox[3] - bbox[1]
|
103 |
|
104 |
-
# Resize the design to fit within 70% of the T-shirt bounding box
|
105 |
scale_factor = 0.7
|
106 |
design_width = int(tshirt_width * scale_factor)
|
107 |
design_height = int(tshirt_height * scale_factor)
|
108 |
-
design_image = design_image.resize((design_width, design_height), Image.
|
109 |
|
110 |
-
# Calculate position to center the design on the T-shirt
|
111 |
x_offset = bbox[0] + (tshirt_width - design_width) // 2
|
112 |
y_offset = bbox[1] + (tshirt_height - design_height) // 2
|
113 |
|
114 |
-
# Create a transparent layer for the design
|
115 |
transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
|
116 |
transparent_layer.paste(design_image, (x_offset, y_offset), design_image)
|
117 |
|
118 |
-
# Blend the T-shirt with the design
|
119 |
final_image = Image.alpha_composite(cloth_image, transparent_layer)
|
120 |
return final_image
|
121 |
|
122 |
def debug_intermediate_outputs(cloth_image, mask):
|
123 |
-
# Save debug images
|
124 |
cloth_image.save("debug_cloth_image.png")
|
125 |
mask.save("debug_tshirt_mask.png")
|
126 |
|
@@ -136,17 +118,16 @@ def design_tshirt(color_prompt, design_prompt):
|
|
136 |
except Exception as e:
|
137 |
raise Exception(f"Error in design process: {str(e)}")
|
138 |
|
139 |
-
# Gradio UI
|
140 |
with gr.Blocks() as interface:
|
141 |
gr.Markdown("# **AI Cloth Designer**")
|
142 |
gr.Markdown("Generate custom T-shirts by specifying a color and adding a design that perfectly fits the T-shirt.")
|
143 |
with gr.Row():
|
144 |
-
with gr.Column(
|
145 |
color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
|
146 |
design_prompt = gr.Textbox(label="Design Details", placeholder="E.g., Abstract art, Nature patterns")
|
147 |
generate_button = gr.Button("Generate T-Shirt")
|
148 |
-
with gr.Column(
|
149 |
-
output_image = gr.Image(label="Final T-Shirt Design"
|
150 |
|
151 |
generate_button.click(
|
152 |
design_tshirt,
|
|
|
8 |
import numpy as np
|
9 |
import cv2
|
10 |
|
|
|
11 |
repo = "artificialguybr/TshirtDesignRedmond-V2"
|
12 |
|
13 |
def generate_cloth(color_prompt):
|
|
|
32 |
else:
|
33 |
raise Exception(f"Error generating design: {response.status_code}")
|
34 |
|
|
|
35 |
segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
|
36 |
|
|
|
37 |
def get_tshirt_mask(image):
|
38 |
image = image.convert("RGB") # Ensure 3 channels
|
39 |
preprocess = T.Compose([
|
|
|
45 |
with torch.no_grad():
|
46 |
output = segmentation_model(input_tensor)["out"][0]
|
47 |
|
|
|
48 |
mask = output.argmax(0).byte().cpu().numpy()
|
49 |
raw_mask = Image.fromarray((mask == 15).astype("uint8") * 255) # Binary mask
|
50 |
processed_mask = post_process_mask(raw_mask) # Apply post-processing
|
51 |
return processed_mask.resize(image.size)
|
52 |
|
|
|
53 |
def post_process_mask(mask):
|
|
|
54 |
mask_np = np.array(mask)
|
55 |
|
|
|
56 |
kernel = np.ones((5, 5), np.uint8)
|
57 |
mask_np = cv2.dilate(mask_np, kernel, iterations=2) # Expand mask
|
58 |
mask_np = cv2.erode(mask_np, kernel, iterations=1) # Remove noise
|
59 |
|
|
|
60 |
processed_mask = Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))
|
61 |
return processed_mask
|
62 |
|
|
|
63 |
def get_bounding_box(mask):
|
64 |
mask_np = np.array(mask)
|
65 |
coords = np.column_stack(np.where(mask_np > 0))
|
|
|
69 |
x_max, y_max = coords.max(axis=0)
|
70 |
return (x_min, y_min, x_max, y_max)
|
71 |
|
|
|
72 |
def visualize_mask(image, mask):
|
73 |
overlay = image.copy().convert("RGBA")
|
74 |
draw = ImageDraw.Draw(overlay)
|
|
|
79 |
return blended
|
80 |
|
81 |
def overlay_design(cloth_image, design_image):
|
|
|
82 |
cloth_image = cloth_image.convert("RGBA")
|
83 |
design_image = design_image.convert("RGBA")
|
84 |
|
|
|
85 |
mask = get_tshirt_mask(cloth_image)
|
86 |
|
|
|
87 |
bbox = get_bounding_box(mask)
|
88 |
tshirt_width = bbox[2] - bbox[0]
|
89 |
tshirt_height = bbox[3] - bbox[1]
|
90 |
|
|
|
91 |
scale_factor = 0.7
|
92 |
design_width = int(tshirt_width * scale_factor)
|
93 |
design_height = int(tshirt_height * scale_factor)
|
94 |
+
design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
|
95 |
|
|
|
96 |
x_offset = bbox[0] + (tshirt_width - design_width) // 2
|
97 |
y_offset = bbox[1] + (tshirt_height - design_height) // 2
|
98 |
|
|
|
99 |
transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
|
100 |
transparent_layer.paste(design_image, (x_offset, y_offset), design_image)
|
101 |
|
|
|
102 |
final_image = Image.alpha_composite(cloth_image, transparent_layer)
|
103 |
return final_image
|
104 |
|
105 |
def debug_intermediate_outputs(cloth_image, mask):
|
|
|
106 |
cloth_image.save("debug_cloth_image.png")
|
107 |
mask.save("debug_tshirt_mask.png")
|
108 |
|
|
|
118 |
except Exception as e:
|
119 |
raise Exception(f"Error in design process: {str(e)}")
|
120 |
|
|
|
121 |
with gr.Blocks() as interface:
|
122 |
gr.Markdown("# **AI Cloth Designer**")
|
123 |
gr.Markdown("Generate custom T-shirts by specifying a color and adding a design that perfectly fits the T-shirt.")
|
124 |
with gr.Row():
|
125 |
+
with gr.Column():
|
126 |
color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
|
127 |
design_prompt = gr.Textbox(label="Design Details", placeholder="E.g., Abstract art, Nature patterns")
|
128 |
generate_button = gr.Button("Generate T-Shirt")
|
129 |
+
with gr.Column():
|
130 |
+
output_image = gr.Image(label="Final T-Shirt Design")
|
131 |
|
132 |
generate_button.click(
|
133 |
design_tshirt,
|