Spaces:
Sleeping
Sleeping
Update warp_design_on_dress.py
Browse files- warp_design_on_dress.py +16 -12
warp_design_on_dress.py
CHANGED
@@ -27,10 +27,10 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
27 |
dress_tensor = tf(dress_img).unsqueeze(0).to(device)
|
28 |
design_tensor = tf(design_img).unsqueeze(0).to(device)
|
29 |
|
30 |
-
# Create design mask
|
31 |
design_mask = torch.ones_like(design_tensor[:, :1, :, :]).to(device)
|
32 |
|
33 |
-
# Prepare agnostic input
|
34 |
agnostic = pad_to_22_channels(dress_tensor).to(device)
|
35 |
|
36 |
# GMM Processing
|
@@ -62,19 +62,23 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
62 |
with torch.no_grad():
|
63 |
# Create simplified feature inputs
|
64 |
gray = dress_tensor.mean(dim=1, keepdim=True)
|
65 |
-
edges = torch.abs(F.conv2d(
|
66 |
-
gray,
|
67 |
-
torch.tensor([[[[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]], dtype=torch.float32, device=device),
|
68 |
-
padding=1
|
69 |
-
))
|
70 |
|
71 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
tom_input = torch.cat([
|
73 |
dress_tensor, # 3 channels
|
74 |
-
warped_design,
|
75 |
-
warped_mask,
|
76 |
-
gray,
|
77 |
-
edges,
|
78 |
torch.zeros_like(dress_tensor)[:, :17] # 17 dummy channels
|
79 |
], dim=1) # Total: 3+3+1+1+1+17 = 26 channels
|
80 |
|
|
|
27 |
dress_tensor = tf(dress_img).unsqueeze(0).to(device)
|
28 |
design_tensor = tf(design_img).unsqueeze(0).to(device)
|
29 |
|
30 |
+
# Create design mask
|
31 |
design_mask = torch.ones_like(design_tensor[:, :1, :, :]).to(device)
|
32 |
|
33 |
+
# Prepare agnostic input
|
34 |
agnostic = pad_to_22_channels(dress_tensor).to(device)
|
35 |
|
36 |
# GMM Processing
|
|
|
62 |
with torch.no_grad():
|
63 |
# Create simplified feature inputs
|
64 |
gray = dress_tensor.mean(dim=1, keepdim=True)
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
+
# Create edge detection kernel
|
67 |
+
kernel = torch.tensor([
|
68 |
+
[[[-1, -1, -1],
|
69 |
+
[-1, 8, -1],
|
70 |
+
[-1, -1, -1]]], dtype=torch.float32, device=device)
|
71 |
+
|
72 |
+
# Calculate edges
|
73 |
+
edges = torch.abs(F.conv2d(gray, kernel, padding=1))
|
74 |
+
|
75 |
+
# Combine inputs
|
76 |
tom_input = torch.cat([
|
77 |
dress_tensor, # 3 channels
|
78 |
+
warped_design, # 3 channels
|
79 |
+
warped_mask, # 1 channel
|
80 |
+
gray, # 1 channel
|
81 |
+
edges, # 1 channel
|
82 |
torch.zeros_like(dress_tensor)[:, :17] # 17 dummy channels
|
83 |
], dim=1) # Total: 3+3+1+1+1+17 = 26 channels
|
84 |
|