sidhtang committed · verified
Commit 3f0cefe · 1 Parent(s): b9eb7d7

Create app.py

Files changed (1)
app.py  +272  -0
app.py ADDED
@@ -0,0 +1,272 @@
import cv2
import numpy as np
import mediapipe as mp
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from enum import Enum
import colorsys
from typing import Tuple, Dict
import torch.nn.functional as F

class ClothingType(Enum):
    SHIRT = "shirt"
    PANTS = "pants"
    DRESS = "dress"
    JACKET = "jacket"

class BodySegmentation(nn.Module):
    def __init__(self):
        super().__init__()
        # Load DeepLabV3 (ResNet-50 backbone) for semantic segmentation
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
        self.model.eval()

    def forward(self, x):
        return self.model(x)['out']

class VirtualTryOn:
    def __init__(self):
        # Initialize MediaPipe
        self.mp_pose = mp.solutions.pose
        self.mp_holistic = mp.solutions.holistic
        self.pose = self.mp_pose.Pose(
            static_image_mode=True,
            model_complexity=2,
            min_detection_confidence=0.5
        )
        self.holistic = self.mp_holistic.Holistic(
            static_image_mode=True,
            model_complexity=2,
            min_detection_confidence=0.5
        )

        # Initialize body segmentation
        self.segmentation = BodySegmentation()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.segmentation.to(self.device)

        # Image transforms (ImageNet normalization expected by the segmentation backbone)
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def get_body_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Get a binary person segmentation mask for a BGR input image
        """
        # Prepare image for the model
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        input_tensor = self.transforms(pil_image).unsqueeze(0).to(self.device)

        # Get segmentation mask
        with torch.no_grad():
            output = self.segmentation(input_tensor)
            mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()

        # Index 15 is the 'person' class in the model's Pascal VOC-style label set
        return (mask == 15).astype(np.uint8)

    def estimate_lighting(self, image: np.ndarray) -> Dict[str, float]:
        """
        Estimate lighting conditions from the image
        """
        # Convert to HSV
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # Get average brightness and saturation
        brightness = np.mean(hsv[:, :, 2])
        saturation = np.mean(hsv[:, :, 1])

        return {
            'brightness': brightness / 255.0,
            'saturation': saturation / 255.0
        }

    def adjust_clothing_color(self, clothing: np.ndarray,
                              lighting_params: Dict[str, float]) -> np.ndarray:
        """
        Adjust clothing colors to match lighting conditions
        """
        # Separate the alpha channel, if present, so the HSV conversion gets 3 channels
        alpha = None
        if clothing.shape[2] == 4:
            alpha = clothing[:, :, 3]
            clothing = clothing[:, :, :3]

        # Convert to HSV for easier adjustment
        hsv = cv2.cvtColor(clothing, cv2.COLOR_BGR2HSV).astype(np.float32)

        # Adjust brightness and saturation
        hsv[:, :, 2] *= lighting_params['brightness']
        hsv[:, :, 1] *= lighting_params['saturation']

        # Ensure values are within valid range
        hsv = np.clip(hsv, 0, 255).astype(np.uint8)

        # Convert back to BGR, re-attaching the alpha channel if there was one
        adjusted = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        if alpha is not None:
            adjusted = np.dstack([adjusted, alpha])
        return adjusted

    def get_clothing_dimensions(self, landmarks, image_shape: Tuple[int, int],
                                clothing_type: ClothingType) -> Dict:
        """
        Get clothing dimensions based on body landmarks and clothing type
        """
        height, width = image_shape[:2]

        if clothing_type in [ClothingType.SHIRT, ClothingType.JACKET]:
            # For upper body clothing
            left_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_SHOULDER]
            right_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.RIGHT_SHOULDER]
            left_hip = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_HIP]

            shoulder_width = abs(right_shoulder.x - left_shoulder.x) * width
            torso_height = abs(left_shoulder.y - left_hip.y) * height

            return {
                'top_left': (
                    int(min(left_shoulder.x, right_shoulder.x) * width),
                    int(left_shoulder.y * height)
                ),
                'width': int(shoulder_width * 1.3),
                'height': int(torso_height * 1.1)
            }

        elif clothing_type == ClothingType.PANTS:
            # For pants
            left_hip = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_HIP]
            right_hip = landmarks.landmark[self.mp_pose.PoseLandmark.RIGHT_HIP]
            left_ankle = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_ANKLE]

            hip_width = abs(right_hip.x - left_hip.x) * width
            leg_height = abs(left_hip.y - left_ankle.y) * height

            return {
                'top_left': (
                    int(min(left_hip.x, right_hip.x) * width),
                    int(left_hip.y * height)
                ),
                'width': int(hip_width * 1.5),
                'height': int(leg_height * 1.05)
            }

        elif clothing_type == ClothingType.DRESS:
            # For dresses
            left_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_SHOULDER]
            right_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.RIGHT_SHOULDER]
            left_knee = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_KNEE]

            shoulder_width = abs(right_shoulder.x - left_shoulder.x) * width
            dress_height = abs(left_shoulder.y - left_knee.y) * height

            return {
                'top_left': (
                    int(min(left_shoulder.x, right_shoulder.x) * width),
                    int(left_shoulder.y * height)
                ),
                'width': int(shoulder_width * 1.4),
                'height': int(dress_height * 1.1)
            }

    def try_on(self, person_image: np.ndarray, clothing_image: np.ndarray,
               clothing_type: ClothingType) -> np.ndarray:
        """
        Enhanced try-on method with support for different clothing types
        """
        # Get body segmentation
        body_mask = self.get_body_segmentation(person_image)

        # Get pose landmarks
        results = self.pose.process(cv2.cvtColor(person_image, cv2.COLOR_BGR2RGB))
        if not results.pose_landmarks:
            raise ValueError("No person detected in the image")

        # Estimate lighting conditions
        lighting_params = self.estimate_lighting(person_image)

        # Adjust clothing colors
        adjusted_clothing = self.adjust_clothing_color(clothing_image, lighting_params)

        # Get clothing dimensions
        dimensions = self.get_clothing_dimensions(
            results.pose_landmarks,
            person_image.shape,
            clothing_type
        )

        # Resize clothing
        clothing_resized = cv2.resize(
            adjusted_clothing,
            (dimensions['width'], dimensions['height']),
            interpolation=cv2.INTER_AREA
        )

        # Create alpha mask for smooth blending
        if clothing_resized.shape[2] == 4:
            alpha_channel = clothing_resized[:, :, 3] / 255.0
        else:
            alpha_channel = np.ones(clothing_resized.shape[:2])

        alpha_3channel = np.stack([alpha_channel] * 3, axis=2)

        # Calculate placement coordinates
        y1 = dimensions['top_left'][1]
        y2 = y1 + dimensions['height']
        x1 = dimensions['top_left'][0]
        x2 = x1 + dimensions['width']

        # Ensure coordinates are within image boundaries
        y1 = max(0, y1)
        y2 = min(person_image.shape[0], y2)
        x1 = max(0, x1)
        x2 = min(person_image.shape[1], x2)

        # Crop the resized clothing and alpha mask to the (possibly clipped) ROI
        roi = person_image[y1:y2, x1:x2]
        clothing_rgb = clothing_resized[:roi.shape[0], :roi.shape[1], :3]
        alpha_3channel = alpha_3channel[:roi.shape[0], :roi.shape[1]]

        # Apply body mask to improve blending
        body_mask_roi = body_mask[y1:y2, x1:x2]
        alpha_3channel = alpha_3channel * np.expand_dims(body_mask_roi, axis=2)

        # Blend images
        blended = (1 - alpha_3channel) * roi + alpha_3channel * clothing_rgb

        result = person_image.copy()
        result[y1:y2, x1:x2] = blended.astype(np.uint8)

        return result

def create_gradio_interface():
    def process_images(person_img, clothing_img, clothing_type):
        try_on = VirtualTryOn()

        # Gradio supplies RGB arrays; the pipeline works in OpenCV's BGR order
        person_bgr = cv2.cvtColor(person_img, cv2.COLOR_RGB2BGR)
        clothing_bgr = cv2.cvtColor(clothing_img, cv2.COLOR_RGB2BGR)

        # Convert clothing type string to enum
        clothing_type_enum = ClothingType(clothing_type.lower())

        # Process the images
        result = try_on.try_on(person_bgr, clothing_bgr, clothing_type_enum)

        # Convert back to RGB for display
        return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)

    # Create the interface
    iface = gr.Interface(
        fn=process_images,
        inputs=[
            gr.Image(label="Upload Person Image"),
            gr.Image(label="Upload Clothing Image"),
            gr.Dropdown(
                choices=["Shirt", "Pants", "Dress", "Jacket"],
                label="Select Clothing Type"
            )
        ],
        outputs=gr.Image(label="Result"),
        title="Virtual Try-On System",
        description="Upload a person's image and a clothing item to see how it looks!",
        examples=[
            ["person.jpg", "shirt.png", "Shirt"],
            ["person.jpg", "pants.png", "Pants"]
        ]
    )

    return iface

if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()
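
Note: app.py imports opencv-python (cv2), numpy, mediapipe, torch, torchvision, Pillow, and gradio, so those packages need to be installed for the Space to run. For quick testing outside the Gradio UI, the pipeline can also be driven directly; the sketch below assumes app.py is importable and uses placeholder image paths (person.jpg, shirt.png), which are not part of this commit.

# Minimal driver sketch (hypothetical file paths, local testing only)
import cv2
from app import VirtualTryOn, ClothingType

person = cv2.imread("person.jpg")                         # OpenCV loads images as BGR
clothing = cv2.imread("shirt.png", cv2.IMREAD_UNCHANGED)  # keeps the alpha channel if the PNG has one

tryon = VirtualTryOn()
result = tryon.try_on(person, clothing, ClothingType.SHIRT)
cv2.imwrite("result.jpg", result)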