gaur3009 commited on
Commit
0d2713b
·
verified ·
1 Parent(s): 832d9a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -80
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from PIL import Image, ImageFilter
3
  import requests
4
  from io import BytesIO
5
  import torch
@@ -11,121 +11,152 @@ import cv2
11
  # AI model repo for design generation
12
  repo = "artificialguybr/TshirtDesignRedmond-V2"
13
 
14
- # ---------- Step 1: Generate Plain T-shirt ----------
15
  def generate_cloth(color_prompt):
16
- prompt = f"A plain {color_prompt.strip().lower()} colored T-shirt hanging on a plain wall."
17
  api_url = f"https://api-inference.huggingface.co/models/{repo}"
 
18
  payload = {"inputs": prompt, "parameters": {"num_inference_steps": 30}}
19
-
20
- try:
21
- response = requests.post(api_url, json=payload, timeout=60)
22
- response.raise_for_status()
23
- image_data = response.content
24
- return Image.open(BytesIO(image_data)).convert("RGB")
25
- except Exception as e:
26
- print(f"Error generating cloth: {e}")
27
- raise Exception("Failed to generate the T-shirt. Check your API key and internet connection.")
28
 
29
- # ---------- Step 2: Generate Design ----------
30
  def generate_design(design_prompt):
31
- prompt = f"A bold {design_prompt.strip().lower()} design with vibrant colors, highly detailed."
32
  api_url = f"https://api-inference.huggingface.co/models/{repo}"
 
33
  payload = {"inputs": prompt, "parameters": {"num_inference_steps": 30}}
34
-
35
- try:
36
- response = requests.post(api_url, json=payload, timeout=60)
37
- response.raise_for_status()
38
- image_data = response.content
39
- return Image.open(BytesIO(image_data)).convert("RGBA")
40
- except Exception as e:
41
- print(f"Error generating design: {e}")
42
- raise Exception("Failed to generate the design. Check your API key and internet connection.")
43
 
44
- # ---------- Step 3: Load Segmentation Model ----------
45
- try:
46
- segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
47
- except Exception as e:
48
- print(f"Error loading segmentation model: {e}")
49
- raise Exception("Failed to load the segmentation model. Check your PyTorch installation.")
50
 
 
51
  def get_tshirt_mask(image):
52
- transform = T.Compose([
53
- T.Resize(520),
 
54
  T.ToTensor(),
55
- T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
56
  ])
57
- input_tensor = transform(image).unsqueeze(0)
58
-
59
  with torch.no_grad():
60
- output = segmentation_model(input_tensor)['out'][0]
61
 
 
62
  mask = output.argmax(0).byte().cpu().numpy()
63
- mask_img = Image.fromarray((mask == 0).astype('uint8') * 255).convert("L")
64
-
65
- return mask_img.filter(ImageFilter.GaussianBlur(10)).resize(image.size)
66
 
67
- # ---------- Step 4: Fit Design to T-shirt Area ----------
68
- def fit_design_to_tshirt(design_image, mask):
 
69
  mask_np = np.array(mask)
70
- contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
71
 
72
- if contours:
73
- # Largest contour represents the T-shirt
74
- c = max(contours, key=cv2.contourArea)
75
- x, y, w, h = cv2.boundingRect(c)
76
-
77
- # Log measurements of the T-shirt area
78
- print(f"T-shirt Area Detected at (x: {x}, y: {y}), Width: {w}px, Height: {h}px")
79
-
80
- # Resize design to fit within the T-shirt area
81
- design_resized = design_image.resize((w, h), Image.Resampling.LANCZOS)
82
-
83
- # Overlay design on a transparent canvas of the same size
84
- canvas = Image.new("RGBA", mask.size, (0, 0, 0, 0))
85
- canvas.paste(design_resized, (x, y), design_resized)
86
-
87
- return canvas
88
- else:
89
- print("No contour found on the mask.")
90
- return design_image
91
-
92
- # ---------- Step 5: Overlay Design ----------
 
 
 
 
 
 
 
 
 
93
  def overlay_design(cloth_image, design_image):
 
 
 
 
 
94
  mask = get_tshirt_mask(cloth_image)
95
- fitted_design = fit_design_to_tshirt(design_image, mask)
96
- return Image.alpha_composite(cloth_image.convert("RGBA"), fitted_design)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- # ---------- Step 6: Final Design Pipeline ----------
99
  def design_tshirt(color_prompt, design_prompt):
 
 
100
  try:
101
- cloth_image = generate_cloth(color_prompt)
102
- design_image = generate_design(design_prompt)
 
103
  final_image = overlay_design(cloth_image, design_image)
104
- return cloth_image, design_image, final_image
105
  except Exception as e:
106
- print(f"Error: {e}")
107
- return None, None, None
108
 
109
- # ---------- Step 7: Gradio Interface ----------
110
  with gr.Blocks() as interface:
111
- gr.Markdown("# **AI T-Shirt Designer (Debug Mode)**")
112
- gr.Markdown("Generate a T-shirt with a design that perfectly fits inside the T-shirt boundaries.")
113
-
114
  with gr.Row():
115
  with gr.Column():
116
  color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
117
  design_prompt = gr.Textbox(label="Design Details", placeholder="E.g., Abstract art, Nature patterns")
118
  generate_button = gr.Button("Generate T-Shirt")
119
-
120
  with gr.Column():
121
- plain_tshirt_output = gr.Image(label="Generated Plain T-Shirt")
122
- design_output = gr.Image(label="Generated Design")
123
- final_output = gr.Image(label="Final Fitted T-Shirt Design")
124
 
125
  generate_button.click(
126
  design_tshirt,
127
  inputs=[color_prompt, design_prompt],
128
- outputs=[plain_tshirt_output, design_output, final_output],
129
  )
130
 
131
- interface.launch(debug=True, share=True)
 
1
  import gradio as gr
2
+ from PIL import Image, ImageDraw, ImageFilter
3
  import requests
4
  from io import BytesIO
5
  import torch
 
11
  # AI model repo for design generation
12
  repo = "artificialguybr/TshirtDesignRedmond-V2"
13
 
 
14
  def generate_cloth(color_prompt):
15
+ prompt = f"A plain {color_prompt} colored T-shirt hanging on a plain wall."
16
  api_url = f"https://api-inference.huggingface.co/models/{repo}"
17
+ headers = {}
18
  payload = {"inputs": prompt, "parameters": {"num_inference_steps": 30}}
19
+ response = requests.post(api_url, headers=headers, json=payload)
20
+ if response.status_code == 200:
21
+ return Image.open(BytesIO(response.content)).convert("RGB")
22
+ else:
23
+ raise Exception(f"Error generating cloth: {response.status_code}")
 
 
 
 
24
 
 
25
  def generate_design(design_prompt):
26
+ prompt = f"A bold {design_prompt} design with vibrant colors, highly detailed."
27
  api_url = f"https://api-inference.huggingface.co/models/{repo}"
28
+ headers = {}
29
  payload = {"inputs": prompt, "parameters": {"num_inference_steps": 30}}
30
+ response = requests.post(api_url, headers=headers, json=payload)
31
+ if response.status_code == 200:
32
+ return Image.open(BytesIO(response.content)).convert("RGBA")
33
+ else:
34
+ raise Exception(f"Error generating design: {response.status_code}")
 
 
 
 
35
 
36
+ # Load pretrained DeepLabV3 model for T-shirt segmentation
37
+ segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
 
 
 
 
38
 
39
+ # Apply segmentation to extract T-shirt mask
40
  def get_tshirt_mask(image):
41
+ image = image.convert("RGB") # Ensure 3 channels
42
+ preprocess = T.Compose([
43
+ T.Resize((520, 520)), # Resize to avoid distortion
44
  T.ToTensor(),
45
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
46
  ])
47
+ input_tensor = preprocess(image).unsqueeze(0)
 
48
  with torch.no_grad():
49
+ output = segmentation_model(input_tensor)["out"][0]
50
 
51
+ # Extract T-shirt mask (class 15 in COCO dataset)
52
  mask = output.argmax(0).byte().cpu().numpy()
53
+ raw_mask = Image.fromarray((mask == 15).astype("uint8") * 255) # Binary mask
54
+ processed_mask = post_process_mask(raw_mask) # Apply post-processing
55
+ return processed_mask.resize(image.size)
56
 
57
+ # Post-process mask to improve quality
58
+ def post_process_mask(mask):
59
+ # Convert mask to NumPy array
60
  mask_np = np.array(mask)
 
61
 
62
+ # Morphological operations to refine mask
63
+ kernel = np.ones((5, 5), np.uint8)
64
+ mask_np = cv2.dilate(mask_np, kernel, iterations=2) # Expand mask
65
+ mask_np = cv2.erode(mask_np, kernel, iterations=1) # Remove noise
66
+
67
+ # Convert back to PIL image and smooth
68
+ processed_mask = Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))
69
+ return processed_mask
70
+
71
+ # Get bounding box from mask
72
+ def get_bounding_box(mask):
73
+ mask_np = np.array(mask)
74
+ coords = np.column_stack(np.where(mask_np > 0))
75
+ if coords.size == 0:
76
+ raise Exception("No T-shirt detected in the image.")
77
+ x_min, y_min = coords.min(axis=0)
78
+ x_max, y_max = coords.max(axis=0)
79
+ return (x_min, y_min, x_max, y_max)
80
+
81
+ # Visualize mask and bounding box on the image for debugging
82
+ def visualize_mask(image, mask):
83
+ overlay = image.copy().convert("RGBA")
84
+ draw = ImageDraw.Draw(overlay)
85
+ bbox = get_bounding_box(mask)
86
+ draw.rectangle(bbox, outline="red", width=3) # Draw bounding box
87
+ blended = Image.blend(image.convert("RGBA"), overlay, alpha=0.5) # Overlay mask
88
+ blended.save("debug_visualization.png") # Save debug image
89
+ return blended
90
+
91
+ # Overlay design on the T-shirt
92
  def overlay_design(cloth_image, design_image):
93
+ # Ensure images are in RGBA mode
94
+ cloth_image = cloth_image.convert("RGBA")
95
+ design_image = design_image.convert("RGBA")
96
+
97
+ # Generate T-shirt mask
98
  mask = get_tshirt_mask(cloth_image)
99
+
100
+ # Extract bounding box for precise placement
101
+ bbox = get_bounding_box(mask)
102
+ tshirt_width = bbox[2] - bbox[0]
103
+ tshirt_height = bbox[3] - bbox[1]
104
+
105
+ # Resize the design to fit the T-shirt
106
+ design_width = int(tshirt_width * 0.6)
107
+ design_height = int(tshirt_height * 0.6)
108
+ resized_design = design_image.resize((design_width, design_height))
109
+
110
+ # Position the design in the center of the T-shirt
111
+ design_position = (
112
+ bbox[0] + (tshirt_width - design_width) // 2,
113
+ bbox[1] + (tshirt_height - design_height) // 2
114
+ )
115
+
116
+ # Create a transparent layer for the design
117
+ transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
118
+ transparent_layer.paste(resized_design, design_position, resized_design)
119
+
120
+ # Mask the design to the T-shirt area
121
+ masked_design = Image.composite(transparent_layer, Image.new("RGBA", cloth_image.size), mask)
122
+
123
+ # Combine the cloth image with the masked design
124
+ final_image = Image.alpha_composite(cloth_image, masked_design)
125
+ return final_image
126
+
127
+ def debug_intermediate_outputs(cloth_image, mask):
128
+ # Save debug images
129
+ cloth_image.save("debug_cloth_image.png")
130
+ mask.save("debug_tshirt_mask.png")
131
 
 
132
  def design_tshirt(color_prompt, design_prompt):
133
+ cloth_image = generate_cloth(color_prompt)
134
+ design_image = generate_design(design_prompt)
135
  try:
136
+ mask = get_tshirt_mask(cloth_image)
137
+ debug_intermediate_outputs(cloth_image, mask) # Debugging
138
+ visualize_mask(cloth_image, mask) # Save visualization
139
  final_image = overlay_design(cloth_image, design_image)
140
+ return final_image
141
  except Exception as e:
142
+ raise Exception(f"Error in design process: {str(e)}")
 
143
 
144
+ # Gradio UI
145
  with gr.Blocks() as interface:
146
+ gr.Markdown("# **AI Cloth Designer**")
147
+ gr.Markdown("Generate custom T-shirts by specifying a color and adding a design that perfectly fits the T-shirt.")
 
148
  with gr.Row():
149
  with gr.Column():
150
  color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
151
  design_prompt = gr.Textbox(label="Design Details", placeholder="E.g., Abstract art, Nature patterns")
152
  generate_button = gr.Button("Generate T-Shirt")
 
153
  with gr.Column():
154
+ output_image = gr.Image(label="Final T-Shirt Design")
 
 
155
 
156
  generate_button.click(
157
  design_tshirt,
158
  inputs=[color_prompt, design_prompt],
159
+ outputs=output_image,
160
  )
161
 
162
+ interface.launch(debug=True)