Update warp_design_on_dress.py
warp_design_on_dress.py  CHANGED  +61 -72
@@ -5,94 +5,83 @@ from torchvision import transforms
 from PIL import Image
 import numpy as np
 from networks import GMM, TOM, load_checkpoint, Options
-
+import torchvision.transforms.functional as TF
 
-def ...
-    ...
-    # Preprocessing
-    im_h, im_w = 256, 192
-    tf = transforms.Compose([
-        transforms.Resize((im_h, im_w)),
+def prepare_inputs(dress_path, design_path, height=256, width=192):
+    """Prepare and normalize input images"""
+    transform = transforms.Compose([
+        transforms.Resize((height, width)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
-
-    # Load images
-    dress_img = Image.open(dress_path).convert("RGB")
-    design_img = Image.open(design_path).convert("RGB")
 
-    ...
+    dress_img = Image.open(dress_path).convert('RGB')
+    design_img = Image.open(design_path).convert('RGB')
+
+    dress_tensor = transform(dress_img).unsqueeze(0)
+    design_tensor = transform(design_img).unsqueeze(0)
+
+    # Create mask (assume design has transparent background)
+    design_arr = np.array(design_img)
+    if design_arr.shape[2] == 4:  # Has alpha channel
+        mask = (design_arr[:, :, 3] > 0).astype(np.float32)
+    else:
+        mask = np.ones((design_arr.shape[0], design_arr.shape[1]), dtype=np.float32)
+
+    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
+    mask_tensor = TF.to_tensor(TF.resize(mask_img, (height, width))).unsqueeze(0)
 
-
-    design_mask = torch.ones_like(design_tensor[:, :1, :, :]).to(device)
+    return dress_tensor, design_tensor, mask_tensor
 
-
-    agnostic ...
+def create_agnostic(dress_tensor):
+    """Create agnostic representation of dress"""
+    return dress_tensor.clone()
 
-    ...
+def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Prepare inputs
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dress_tensor, design_tensor, design_mask = prepare_inputs(dress_path, design_path)
+    agnostic = create_agnostic(dress_tensor)
+
+    # Initialize models
     opt = Options()
     gmm = GMM(opt).to(device)
-    ...
+    tom = TOM(opt).to(device)
+
+    # Load checkpoints
+    load_checkpoint(gmm, gmm_ckpt)
+    load_checkpoint(tom, tom_ckpt)
+
+    # Move tensors to device
+    agnostic = agnostic.to(device)
+    design_tensor = design_tensor.to(device)
+    design_mask = design_mask.to(device)
+
+    # GMM Processing
     with torch.no_grad():
+        gmm.eval()
         grid, _ = gmm(agnostic, design_mask)
-        warped_design = F.grid_sample(
-            design_tensor,
-            grid,
-            padding_mode='border',
-            align_corners=True
-        )
-        warped_mask = F.grid_sample(
-            design_mask,
-            grid,
-            padding_mode='zeros',
-            align_corners=True
-        )
-
+        warped_design = F.grid_sample(design_tensor, grid, padding_mode='border', align_corners=True)
+        warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros', align_corners=True)
+
     # TOM Processing
-    tom = TOM(opt).to(device)
-    load_checkpoint(tom, tom_ckpt, strict=False)
-    tom.eval()
-
     with torch.no_grad():
-        ...
-        # Create edge detection kernel
-        kernel = torch.tensor(
-            [[[-1, -1, -1],
-              [-1, 8, -1],
-              [-1, -1, -1]]], dtype=torch.float32, device=device)
-
-        # Calculate edges
-        edges = torch.abs(F.conv2d(gray, kernel, padding=1))
-
-        # Combine inputs
-        tom_input = torch.cat([
-            dress_tensor,   # 3 channels
-            warped_design,  # 3 channels
-            warped_mask,    # 1 channel
-            gray,           # 1 channel
-            edges,          # 1 channel
-            torch.zeros_like(dress_tensor)[:, :17]  # 17 dummy channels
-        ], dim=1)  # Total: 3+3+1+1+1+17 = 26 channels
-
-        # Generate try-on result
+        tom.eval()
+        # Prepare TOM input: [agnostic, warped_design, warped_mask]
+        tom_input = torch.cat([agnostic, warped_design, warped_mask], dim=1)
         p_rendered, m_composite = tom(tom_input)
+
+        # Final composition
         tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
-
-    # Convert to PIL image
-    tryon = tryon.squeeze().detach().cpu()
-    tryon = (tryon.permute(1, 2, 0).numpy() + 1) * 127.5
-    tryon = np.clip(tryon, 0, 255).astype("uint8")
-    out_pil = Image.fromarray(tryon)
 
     # Save output
-    ...
+    tryon = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
+    tryon = (tryon * 0.5 + 0.5) * 255
+    tryon = tryon.clip(0, 255).astype(np.uint8)
+
+    output_path = os.path.join(output_dir, "warped_design.jpg")
+    Image.fromarray(tryon).save(output_path)
+
     return output_path
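
Reviewer note on the TOM input change: the new tom_input concatenates agnostic (3 channels) + warped_design (3) + warped_mask (1) = 7 channels. The removed block's own arithmetic never added up: its comment claimed 3+3+1+1+1+17 = 26 (it is 25), and torch.zeros_like(dress_tensor)[:, :17] can only yield 3 channels because dress_tensor has 3, so the old cat actually produced 12, which is presumably part of why it was rewritten. A quick shape check for the new layout, using dummy tensors at the script's 256 x 192 working resolution (TOM's Options must of course be configured for 7 input channels):

import torch

# Dummy tensors matching the shapes produced by the updated script
# (batch 1, 256 x 192 working resolution).
agnostic = torch.zeros(1, 3, 256, 192)       # dress image, normalized
warped_design = torch.zeros(1, 3, 256, 192)  # GMM-warped design
warped_mask = torch.zeros(1, 1, 256, 192)    # GMM-warped design mask

tom_input = torch.cat([agnostic, warped_design, warped_mask], dim=1)
assert tom_input.shape == (1, 7, 256, 192)   # TOM must expect 7 in-channels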
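
The output conversion also changed form but not behavior: (t + 1) * 127.5 and (t * 0.5 + 0.5) * 255 are the same inverse of Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), both mapping [-1, 1] back to [0, 255]. A one-line check:

import torch

# Both de-normalization formulas map [-1, 1] -> [0, 255] identically.
t = torch.linspace(-1, 1, steps=5)
assert torch.allclose((t + 1) * 127.5, (t * 0.5 + 0.5) * 255)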
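
For anyone smoke-testing the new entry point, a minimal driver might look like the sketch below; every path in it is a hypothetical placeholder, and networks.py with its GMM/TOM checkpoints is assumed to be available in the Space as before:

from warp_design_on_dress import run_design_warp_on_dress

# All paths are placeholders; substitute the Space's real sample
# images and checkpoint files.
result = run_design_warp_on_dress(
    dress_path="samples/dress.jpg",        # base garment photo
    design_path="samples/design.png",      # design/print image, ideally RGBA
    gmm_ckpt="checkpoints/gmm_final.pth",  # GMM (warping) checkpoint
    tom_ckpt="checkpoints/tom_final.pth",  # TOM (try-on) checkpoint
    output_dir="output",
)
print("Saved:", result)  # -> output/warped_design.jpg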