gaur3009 commited on
Commit
f869f83
·
verified ·
1 Parent(s): d273e04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -13
app.py CHANGED
@@ -2,10 +2,11 @@ import gradio as gr
2
  from PIL import Image, ImageDraw, ImageFilter
3
  import requests
4
  from io import BytesIO
5
- import numpy as np
6
  import torch
7
  import torchvision.transforms as T
8
  from torchvision import models
 
 
9
 
10
  # AI model repo for design generation
11
  repo = "artificialguybr/TshirtDesignRedmond-V2"
@@ -32,21 +33,83 @@ def generate_design(design_prompt):
32
  else:
33
  raise Exception(f"Error generating design: {response.status_code}")
34
 
35
- def overlay_design_on_tshirt(cloth_image, design_image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Ensure images are in RGBA mode
37
  cloth_image = cloth_image.convert("RGBA")
38
  design_image = design_image.convert("RGBA")
39
 
40
- # Resize the design to fit within 70% of the T-shirt's dimensions
 
 
 
 
 
 
 
 
41
  scale_factor = 0.7
42
- width, height = cloth_image.size
43
- design_width = int(width * scale_factor)
44
- design_height = int(height * scale_factor)
45
- design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
46
 
47
  # Calculate position to center the design on the T-shirt
48
- x_offset = (width - design_width) // 2
49
- y_offset = (height - design_height) // 2
50
 
51
  # Create a transparent layer for the design
52
  transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
@@ -56,11 +119,19 @@ def overlay_design_on_tshirt(cloth_image, design_image):
56
  final_image = Image.alpha_composite(cloth_image, transparent_layer)
57
  return final_image
58
 
 
 
 
 
 
59
  def design_tshirt(color_prompt, design_prompt):
 
 
60
  try:
61
- cloth_image = generate_cloth(color_prompt)
62
- design_image = generate_design(design_prompt)
63
- final_image = overlay_design_on_tshirt(cloth_image, design_image)
 
64
  return final_image
65
  except Exception as e:
66
  raise Exception(f"Error in design process: {str(e)}")
@@ -83,4 +154,4 @@ with gr.Blocks() as interface:
83
  outputs=output_image,
84
  )
85
 
86
- interface.launch(debug=True)
 
2
  from PIL import Image, ImageDraw, ImageFilter
3
  import requests
4
  from io import BytesIO
 
5
  import torch
6
  import torchvision.transforms as T
7
  from torchvision import models
8
+ import numpy as np
9
+ import cv2
10
 
11
  # AI model repo for design generation
12
  repo = "artificialguybr/TshirtDesignRedmond-V2"
 
33
  else:
34
  raise Exception(f"Error generating design: {response.status_code}")
35
 
36
+ # Load pretrained DeepLabV3 model for T-shirt segmentation
37
+ segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
38
+
39
+ # Apply segmentation to extract T-shirt mask
40
+ def get_tshirt_mask(image):
41
+ image = image.convert("RGB") # Ensure 3 channels
42
+ preprocess = T.Compose([
43
+ T.Resize((520, 520)), # Resize to avoid distortion
44
+ T.ToTensor(),
45
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
46
+ ])
47
+ input_tensor = preprocess(image).unsqueeze(0)
48
+ with torch.no_grad():
49
+ output = segmentation_model(input_tensor)["out"][0]
50
+
51
+ # Extract T-shirt mask (class 15 in COCO dataset)
52
+ mask = output.argmax(0).byte().cpu().numpy()
53
+ raw_mask = Image.fromarray((mask == 15).astype("uint8") * 255) # Binary mask
54
+ processed_mask = post_process_mask(raw_mask) # Apply post-processing
55
+ return processed_mask.resize(image.size)
56
+
57
+ # Post-process mask to improve quality
58
+ def post_process_mask(mask):
59
+ # Convert mask to NumPy array
60
+ mask_np = np.array(mask)
61
+
62
+ # Morphological operations to refine mask
63
+ kernel = np.ones((5, 5), np.uint8)
64
+ mask_np = cv2.dilate(mask_np, kernel, iterations=2) # Expand mask
65
+ mask_np = cv2.erode(mask_np, kernel, iterations=1) # Remove noise
66
+
67
+ # Convert back to PIL image and smooth
68
+ processed_mask = Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))
69
+ return processed_mask
70
+
71
+ # Get bounding box from mask
72
+ def get_bounding_box(mask):
73
+ mask_np = np.array(mask)
74
+ coords = np.column_stack(np.where(mask_np > 0))
75
+ if coords.size == 0:
76
+ raise Exception("No T-shirt detected in the image.")
77
+ x_min, y_min = coords.min(axis=0)
78
+ x_max, y_max = coords.max(axis=0)
79
+ return (x_min, y_min, x_max, y_max)
80
+
81
+ # Visualize mask and bounding box on the image for debugging
82
+ def visualize_mask(image, mask):
83
+ overlay = image.copy().convert("RGBA")
84
+ draw = ImageDraw.Draw(overlay)
85
+ bbox = get_bounding_box(mask)
86
+ draw.rectangle(bbox, outline="red", width=3) # Draw bounding box
87
+ blended = Image.blend(image.convert("RGBA"), overlay, alpha=0.5) # Overlay mask
88
+ blended.save("debug_visualization.png") # Save debug image
89
+ return blended
90
+
91
+ def overlay_design(cloth_image, design_image):
92
  # Ensure images are in RGBA mode
93
  cloth_image = cloth_image.convert("RGBA")
94
  design_image = design_image.convert("RGBA")
95
 
96
+ # Generate T-shirt mask
97
+ mask = get_tshirt_mask(cloth_image)
98
+
99
+ # Extract bounding box and T-shirt coordinates
100
+ bbox = get_bounding_box(mask)
101
+ tshirt_width = bbox[2] - bbox[0]
102
+ tshirt_height = bbox[3] - bbox[1]
103
+
104
+ # Resize the design to fit within 70% of the T-shirt bounding box
105
  scale_factor = 0.7
106
+ design_width = int(tshirt_width * scale_factor)
107
+ design_height = int(tshirt_height * scale_factor)
108
+ design_image = design_image.resize((design_width, design_height), Image.ANTIALIAS)
 
109
 
110
  # Calculate position to center the design on the T-shirt
111
+ x_offset = bbox[0] + (tshirt_width - design_width) // 2
112
+ y_offset = bbox[1] + (tshirt_height - design_height) // 2
113
 
114
  # Create a transparent layer for the design
115
  transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
 
119
  final_image = Image.alpha_composite(cloth_image, transparent_layer)
120
  return final_image
121
 
122
+ def debug_intermediate_outputs(cloth_image, mask):
123
+ # Save debug images
124
+ cloth_image.save("debug_cloth_image.png")
125
+ mask.save("debug_tshirt_mask.png")
126
+
127
  def design_tshirt(color_prompt, design_prompt):
128
+ cloth_image = generate_cloth(color_prompt)
129
+ design_image = generate_design(design_prompt)
130
  try:
131
+ mask = get_tshirt_mask(cloth_image)
132
+ debug_intermediate_outputs(cloth_image, mask) # Debugging
133
+ visualize_mask(cloth_image, mask) # Save visualization
134
+ final_image = overlay_design(cloth_image, design_image)
135
  return final_image
136
  except Exception as e:
137
  raise Exception(f"Error in design process: {str(e)}")
 
154
  outputs=output_image,
155
  )
156
 
157
+ interface.launch(debug=True)