gaur3009 commited on
Commit
074903a
·
verified ·
1 Parent(s): 1420ac5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -23
app.py CHANGED
@@ -8,7 +8,6 @@ from torchvision import models
8
  import numpy as np
9
  import cv2
10
 
11
- # AI model repo for design generation
12
  repo = "artificialguybr/TshirtDesignRedmond-V2"
13
 
14
  def generate_cloth(color_prompt):
@@ -33,10 +32,8 @@ def generate_design(design_prompt):
33
  else:
34
  raise Exception(f"Error generating design: {response.status_code}")
35
 
36
- # Load pretrained DeepLabV3 model for T-shirt segmentation
37
  segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
38
 
39
- # Apply segmentation to extract T-shirt mask
40
  def get_tshirt_mask(image):
41
  image = image.convert("RGB") # Ensure 3 channels
42
  preprocess = T.Compose([
@@ -48,27 +45,21 @@ def get_tshirt_mask(image):
48
  with torch.no_grad():
49
  output = segmentation_model(input_tensor)["out"][0]
50
 
51
- # Extract T-shirt mask (class 15 in COCO dataset)
52
  mask = output.argmax(0).byte().cpu().numpy()
53
  raw_mask = Image.fromarray((mask == 15).astype("uint8") * 255) # Binary mask
54
  processed_mask = post_process_mask(raw_mask) # Apply post-processing
55
  return processed_mask.resize(image.size)
56
 
57
- # Post-process mask to improve quality
58
  def post_process_mask(mask):
59
- # Convert mask to NumPy array
60
  mask_np = np.array(mask)
61
 
62
- # Morphological operations to refine mask
63
  kernel = np.ones((5, 5), np.uint8)
64
  mask_np = cv2.dilate(mask_np, kernel, iterations=2) # Expand mask
65
  mask_np = cv2.erode(mask_np, kernel, iterations=1) # Remove noise
66
 
67
- # Convert back to PIL image and smooth
68
  processed_mask = Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))
69
  return processed_mask
70
 
71
- # Get bounding box from mask
72
  def get_bounding_box(mask):
73
  mask_np = np.array(mask)
74
  coords = np.column_stack(np.where(mask_np > 0))
@@ -78,7 +69,6 @@ def get_bounding_box(mask):
78
  x_max, y_max = coords.max(axis=0)
79
  return (x_min, y_min, x_max, y_max)
80
 
81
- # Visualize mask and bounding box on the image for debugging
82
  def visualize_mask(image, mask):
83
  overlay = image.copy().convert("RGBA")
84
  draw = ImageDraw.Draw(overlay)
@@ -89,38 +79,30 @@ def visualize_mask(image, mask):
89
  return blended
90
 
91
  def overlay_design(cloth_image, design_image):
92
- # Ensure images are in RGBA mode
93
  cloth_image = cloth_image.convert("RGBA")
94
  design_image = design_image.convert("RGBA")
95
 
96
- # Generate T-shirt mask
97
  mask = get_tshirt_mask(cloth_image)
98
 
99
- # Extract bounding box and T-shirt coordinates
100
  bbox = get_bounding_box(mask)
101
  tshirt_width = bbox[2] - bbox[0]
102
  tshirt_height = bbox[3] - bbox[1]
103
 
104
- # Resize the design to fit within 70% of the T-shirt bounding box
105
  scale_factor = 0.7
106
  design_width = int(tshirt_width * scale_factor)
107
  design_height = int(tshirt_height * scale_factor)
108
- design_image = design_image.resize((design_width, design_height), Image.ANTIALIAS)
109
 
110
- # Calculate position to center the design on the T-shirt
111
  x_offset = bbox[0] + (tshirt_width - design_width) // 2
112
  y_offset = bbox[1] + (tshirt_height - design_height) // 2
113
 
114
- # Create a transparent layer for the design
115
  transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
116
  transparent_layer.paste(design_image, (x_offset, y_offset), design_image)
117
 
118
- # Blend the T-shirt with the design
119
  final_image = Image.alpha_composite(cloth_image, transparent_layer)
120
  return final_image
121
 
122
  def debug_intermediate_outputs(cloth_image, mask):
123
- # Save debug images
124
  cloth_image.save("debug_cloth_image.png")
125
  mask.save("debug_tshirt_mask.png")
126
 
@@ -136,17 +118,16 @@ def design_tshirt(color_prompt, design_prompt):
136
  except Exception as e:
137
  raise Exception(f"Error in design process: {str(e)}")
138
 
139
- # Gradio UI
140
  with gr.Blocks() as interface:
141
  gr.Markdown("# **AI Cloth Designer**")
142
  gr.Markdown("Generate custom T-shirts by specifying a color and adding a design that perfectly fits the T-shirt.")
143
  with gr.Row():
144
- with gr.Column(scale=1):
145
  color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
146
  design_prompt = gr.Textbox(label="Design Details", placeholder="E.g., Abstract art, Nature patterns")
147
  generate_button = gr.Button("Generate T-Shirt")
148
- with gr.Column(scale=4):
149
- output_image = gr.Image(label="Final T-Shirt Design", type="pil", height=500)
150
 
151
  generate_button.click(
152
  design_tshirt,
 
8
  import numpy as np
9
  import cv2
10
 
 
11
  repo = "artificialguybr/TshirtDesignRedmond-V2"
12
 
13
  def generate_cloth(color_prompt):
 
32
  else:
33
  raise Exception(f"Error generating design: {response.status_code}")
34
 
 
35
  segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
36
 
 
37
  def get_tshirt_mask(image):
38
  image = image.convert("RGB") # Ensure 3 channels
39
  preprocess = T.Compose([
 
45
  with torch.no_grad():
46
  output = segmentation_model(input_tensor)["out"][0]
47
 
 
48
  mask = output.argmax(0).byte().cpu().numpy()
49
  raw_mask = Image.fromarray((mask == 15).astype("uint8") * 255) # Binary mask
50
  processed_mask = post_process_mask(raw_mask) # Apply post-processing
51
  return processed_mask.resize(image.size)
52
 
 
53
  def post_process_mask(mask):
 
54
  mask_np = np.array(mask)
55
 
 
56
  kernel = np.ones((5, 5), np.uint8)
57
  mask_np = cv2.dilate(mask_np, kernel, iterations=2) # Expand mask
58
  mask_np = cv2.erode(mask_np, kernel, iterations=1) # Remove noise
59
 
 
60
  processed_mask = Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))
61
  return processed_mask
62
 
 
63
  def get_bounding_box(mask):
64
  mask_np = np.array(mask)
65
  coords = np.column_stack(np.where(mask_np > 0))
 
69
  x_max, y_max = coords.max(axis=0)
70
  return (x_min, y_min, x_max, y_max)
71
 
 
72
  def visualize_mask(image, mask):
73
  overlay = image.copy().convert("RGBA")
74
  draw = ImageDraw.Draw(overlay)
 
79
  return blended
80
 
81
  def overlay_design(cloth_image, design_image):
 
82
  cloth_image = cloth_image.convert("RGBA")
83
  design_image = design_image.convert("RGBA")
84
 
 
85
  mask = get_tshirt_mask(cloth_image)
86
 
 
87
  bbox = get_bounding_box(mask)
88
  tshirt_width = bbox[2] - bbox[0]
89
  tshirt_height = bbox[3] - bbox[1]
90
 
 
91
  scale_factor = 0.7
92
  design_width = int(tshirt_width * scale_factor)
93
  design_height = int(tshirt_height * scale_factor)
94
+ design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)
95
 
 
96
  x_offset = bbox[0] + (tshirt_width - design_width) // 2
97
  y_offset = bbox[1] + (tshirt_height - design_height) // 2
98
 
 
99
  transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
100
  transparent_layer.paste(design_image, (x_offset, y_offset), design_image)
101
 
 
102
  final_image = Image.alpha_composite(cloth_image, transparent_layer)
103
  return final_image
104
 
105
  def debug_intermediate_outputs(cloth_image, mask):
 
106
  cloth_image.save("debug_cloth_image.png")
107
  mask.save("debug_tshirt_mask.png")
108
 
 
118
  except Exception as e:
119
  raise Exception(f"Error in design process: {str(e)}")
120
 
 
121
  with gr.Blocks() as interface:
122
  gr.Markdown("# **AI Cloth Designer**")
123
  gr.Markdown("Generate custom T-shirts by specifying a color and adding a design that perfectly fits the T-shirt.")
124
  with gr.Row():
125
+ with gr.Column():
126
  color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
127
  design_prompt = gr.Textbox(label="Design Details", placeholder="E.g., Abstract art, Nature patterns")
128
  generate_button = gr.Button("Generate T-Shirt")
129
+ with gr.Column():
130
+ output_image = gr.Image(label="Final T-Shirt Design")
131
 
132
  generate_button.click(
133
  design_tshirt,