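"""
Virtual try-on prototype.

Uses MediaPipe pose landmarks to size and place a clothing image on a person,
a pretrained DeepLabV3 model for person segmentation, a simple HSV-based
lighting adjustment, and a Gradio interface for interactive use.
"""
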
import cv2
import numpy as np
import mediapipe as mp
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from enum import Enum
import colorsys
from typing import Tuple, Dict
import torch.nn.functional as F

class ClothingType(Enum):
    SHIRT = "shirt"
    PANTS = "pants"
    DRESS = "dress"
    JACKET = "jacket"

class BodySegmentation(nn.Module):
    def __init__(self):
        super().__init__()
        # Load a pretrained DeepLabV3 (ResNet-50 backbone) model for semantic segmentation
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
        self.model.eval()
        
    def forward(self, x):
        return self.model(x)['out']

class VirtualTryOn:
    def __init__(self):
        # Initialize MediaPipe
        self.mp_pose = mp.solutions.pose
        self.mp_holistic = mp.solutions.holistic
        self.pose = self.mp_pose.Pose(
            static_image_mode=True,
            model_complexity=2,
            min_detection_confidence=0.5
        )
        self.holistic = self.mp_holistic.Holistic(
            static_image_mode=True,
            model_complexity=2,
            min_detection_confidence=0.5
        )
        
        # Initialize body segmentation
        self.segmentation = BodySegmentation()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.segmentation.to(self.device)
        
        # Image transforms
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                              std=[0.229, 0.224, 0.225])
        ])
    
    def get_body_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Get a binary person mask from the DeepLabV3 segmentation model
        """
        # Prepare image for model
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        input_tensor = self.transforms(pil_image).unsqueeze(0).to(self.device)
        
        # Get segmentation mask
        with torch.no_grad():
            output = self.segmentation(input_tensor)
            mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()
            
        # Class index 15 is 'person' in the Pascal VOC label set used by
        # torchvision's pretrained DeepLabV3
        return (mask == 15).astype(np.uint8)
    
    def estimate_lighting(self, image: np.ndarray) -> Dict[str, float]:
        """
        Estimate lighting conditions from the image
        """
        # Convert to HSV
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        
        # Get average brightness and saturation
        brightness = np.mean(hsv[:, :, 2])
        saturation = np.mean(hsv[:, :, 1])
        
        return {
            'brightness': brightness / 255.0,
            'saturation': saturation / 255.0
        }
    
    def adjust_clothing_color(self, clothing: np.ndarray, 
                            lighting_params: Dict[str, float]) -> np.ndarray:
        """
        Adjust clothing colors to match the lighting conditions of the person image
        """
        # Preserve the alpha channel, if present, since the HSV conversion
        # below only accepts 3-channel images
        alpha = clothing[:, :, 3] if clothing.shape[2] == 4 else None
        bgr = np.ascontiguousarray(clothing[:, :, :3])
        
        # Convert to HSV so brightness and saturation can be adjusted directly
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
        
        # Scale brightness and saturation toward the estimated lighting
        hsv[:, :, 2] *= lighting_params['brightness']
        hsv[:, :, 1] *= lighting_params['saturation']
        
        # Clip to the valid 8-bit range and convert back to BGR
        hsv = np.clip(hsv, 0, 255).astype(np.uint8)
        adjusted = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        
        # Re-attach the alpha channel so downstream blending can still use it
        if alpha is not None:
            adjusted = np.dstack([adjusted, alpha])
        
        return adjusted
    
    def get_clothing_dimensions(self, landmarks, image_shape: Tuple[int, int], 
                              clothing_type: ClothingType) -> Dict:
        """
        Get clothing dimensions based on body landmarks and clothing type
        """
        height, width = image_shape[:2]
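        # The multipliers applied below (e.g. 1.3x shoulder width, 1.1x torso
        # height) pad the landmark-derived box so the garment overlaps the
        # body instead of ending exactly at the joints.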
        
        if clothing_type in [ClothingType.SHIRT, ClothingType.JACKET]:
            # For upper body clothing
            left_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_SHOULDER]
            right_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.RIGHT_SHOULDER]
            left_hip = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_HIP]
            
            shoulder_width = abs(right_shoulder.x - left_shoulder.x) * width
            torso_height = abs(left_shoulder.y - left_hip.y) * height
            
            return {
                'top_left': (
                    int(min(left_shoulder.x, right_shoulder.x) * width),
                    int(left_shoulder.y * height)
                ),
                'width': int(shoulder_width * 1.3),
                'height': int(torso_height * 1.1)
            }
            
        elif clothing_type == ClothingType.PANTS:
            # For pants
            left_hip = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_HIP]
            right_hip = landmarks.landmark[self.mp_pose.PoseLandmark.RIGHT_HIP]
            left_ankle = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_ANKLE]
            
            hip_width = abs(right_hip.x - left_hip.x) * width
            leg_height = abs(left_hip.y - left_ankle.y) * height
            
            return {
                'top_left': (
                    int(min(left_hip.x, right_hip.x) * width),
                    int(left_hip.y * height)
                ),
                'width': int(hip_width * 1.5),
                'height': int(leg_height * 1.05)
            }
            
        elif clothing_type == ClothingType.DRESS:
            # For dresses
            left_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_SHOULDER]
            right_shoulder = landmarks.landmark[self.mp_pose.PoseLandmark.RIGHT_SHOULDER]
            left_knee = landmarks.landmark[self.mp_pose.PoseLandmark.LEFT_KNEE]
            
            shoulder_width = abs(right_shoulder.x - left_shoulder.x) * width
            dress_height = abs(left_shoulder.y - left_knee.y) * height
            
            return {
                'top_left': (
                    int(min(left_shoulder.x, right_shoulder.x) * width),
                    int(left_shoulder.y * height)
                ),
                'width': int(shoulder_width * 1.4),
                'height': int(dress_height * 1.1)
            }
    
    def try_on(self, person_image: np.ndarray, clothing_image: np.ndarray, 
               clothing_type: ClothingType) -> np.ndarray:
        """
        Overlay the clothing image on the person image for the given clothing type
        """
        # Get body segmentation
        body_mask = self.get_body_segmentation(person_image)
        
        # Get pose landmarks
        results = self.pose.process(cv2.cvtColor(person_image, cv2.COLOR_BGR2RGB))
        if not results.pose_landmarks:
            raise ValueError("No person detected in the image")
        
        # Estimate lighting conditions
        lighting_params = self.estimate_lighting(person_image)
        
        # Adjust clothing colors
        adjusted_clothing = self.adjust_clothing_color(clothing_image, lighting_params)
        
        # Get clothing dimensions
        dimensions = self.get_clothing_dimensions(
            results.pose_landmarks, 
            person_image.shape, 
            clothing_type
        )
        
        # Resize clothing
        clothing_resized = cv2.resize(
            adjusted_clothing,
            (dimensions['width'], dimensions['height']),
            interpolation=cv2.INTER_AREA
        )
        
        # Create an alpha mask for smooth blending: use the clothing's own
        # alpha channel if present, otherwise treat it as fully opaque
        if clothing_resized.shape[2] == 4:
            alpha_channel = clothing_resized[:, :, 3] / 255.0
        else:
            alpha_channel = np.ones(clothing_resized.shape[:2])
        
        alpha_3channel = np.stack([alpha_channel] * 3, axis=2)
        
        # Calculate placement coordinates
        y1 = dimensions['top_left'][1]
        y2 = y1 + dimensions['height']
        x1 = dimensions['top_left'][0]
        x2 = x1 + dimensions['width']
        
        # Clip the placement rectangle to the image boundaries
        y1 = max(0, y1)
        y2 = min(person_image.shape[0], y2)
        x1 = max(0, x1)
        x2 = min(person_image.shape[1], x2)
        
        # Crop the clothing and alpha mask to the visible region so their
        # shapes match the ROI after clipping
        roi = person_image[y1:y2, x1:x2]
        roi_h, roi_w = roi.shape[:2]
        clothing_rgb = clothing_resized[:roi_h, :roi_w, :3]
        alpha_3channel = alpha_3channel[:roi_h, :roi_w]
        
        # Restrict the alpha mask to the segmented body so the clothing
        # does not spill onto the background
        body_mask_roi = body_mask[y1:y2, x1:x2]
        alpha_3channel = alpha_3channel * np.expand_dims(body_mask_roi, axis=2)
        
        # Alpha-blend the clothing onto the region of interest
        blended = (1 - alpha_3channel) * roi + alpha_3channel * clothing_rgb
        
        result = person_image.copy()
        result[y1:y2, x1:x2] = blended.astype(person_image.dtype)
        
        return result

def create_gradio_interface():
    # Build the pipeline once so the models are not reloaded on every request
    try_on = VirtualTryOn()
    
    def process_images(person_img, clothing_img, clothing_type):
        # Gradio supplies RGB(A) arrays; the pipeline works in OpenCV's BGR order
        person_bgr = cv2.cvtColor(person_img, cv2.COLOR_RGB2BGR)
        if clothing_img.ndim == 3 and clothing_img.shape[2] == 4:
            clothing_bgr = cv2.cvtColor(clothing_img, cv2.COLOR_RGBA2BGRA)
        else:
            clothing_bgr = cv2.cvtColor(clothing_img, cv2.COLOR_RGB2BGR)
        
        # Convert the clothing type string to its enum value
        clothing_type_enum = ClothingType(clothing_type.lower())
        
        # Run the try-on pipeline and convert the result back to RGB for display
        result = try_on.try_on(person_bgr, clothing_bgr, clothing_type_enum)
        return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    
    # Create the interface
    iface = gr.Interface(
        fn=process_images,
        inputs=[
            gr.Image(label="Upload Person Image"),
            gr.Image(label="Upload Clothing Image"),
            gr.Dropdown(
                choices=["Shirt", "Pants", "Dress", "Jacket"],
                label="Select Clothing Type"
            )
        ],
        outputs=gr.Image(label="Result"),
        title="Virtual Try-On System",
        description="Upload a person's image and a clothing item to see how it looks!",
        examples=[
            ["person.jpg", "shirt.png", "Shirt"],
            ["person.jpg", "pants.png", "Pants"]
        ]
    )
    
    return iface

if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()