gaur3009 commited on
Commit
40d3928
·
verified ·
1 Parent(s): 97b4c52

Update warp_design_on_dress.py

Browse files
Files changed (1) hide show
  1. warp_design_on_dress.py +16 -12
warp_design_on_dress.py CHANGED
@@ -27,10 +27,10 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
27
  dress_tensor = tf(dress_img).unsqueeze(0).to(device)
28
  design_tensor = tf(design_img).unsqueeze(0).to(device)
29
 
30
- # Create design mask (assuming design covers entire area)
31
  design_mask = torch.ones_like(design_tensor[:, :1, :, :]).to(device)
32
 
33
- # Prepare agnostic input (dress image)
34
  agnostic = pad_to_22_channels(dress_tensor).to(device)
35
 
36
  # GMM Processing
@@ -62,19 +62,23 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
62
  with torch.no_grad():
63
  # Create simplified feature inputs
64
  gray = dress_tensor.mean(dim=1, keepdim=True)
65
- edges = torch.abs(F.conv2d(
66
- gray,
67
- torch.tensor([[[[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]], dtype=torch.float32, device=device),
68
- padding=1
69
- ))
70
 
71
- # Combine inputs (agnostic + warped design + warped mask + features)
 
 
 
 
 
 
 
 
 
72
  tom_input = torch.cat([
73
  dress_tensor, # 3 channels
74
- warped_design, # 3 channels
75
- warped_mask, # 1 channel
76
- gray, # 1 channel
77
- edges, # 1 channel
78
  torch.zeros_like(dress_tensor)[:, :17] # 17 dummy channels
79
  ], dim=1) # Total: 3+3+1+1+1+17 = 26 channels
80
 
 
27
  dress_tensor = tf(dress_img).unsqueeze(0).to(device)
28
  design_tensor = tf(design_img).unsqueeze(0).to(device)
29
 
30
+ # Create design mask
31
  design_mask = torch.ones_like(design_tensor[:, :1, :, :]).to(device)
32
 
33
+ # Prepare agnostic input
34
  agnostic = pad_to_22_channels(dress_tensor).to(device)
35
 
36
  # GMM Processing
 
62
  with torch.no_grad():
63
  # Create simplified feature inputs
64
  gray = dress_tensor.mean(dim=1, keepdim=True)
 
 
 
 
 
65
 
66
+ # Create edge detection kernel
67
+ kernel = torch.tensor([
68
+ [[[-1, -1, -1],
69
+ [-1, 8, -1],
70
+ [-1, -1, -1]]], dtype=torch.float32, device=device)
71
+
72
+ # Calculate edges
73
+ edges = torch.abs(F.conv2d(gray, kernel, padding=1))
74
+
75
+ # Combine inputs
76
  tom_input = torch.cat([
77
  dress_tensor, # 3 channels
78
+ warped_design, # 3 channels
79
+ warped_mask, # 1 channel
80
+ gray, # 1 channel
81
+ edges, # 1 channel
82
  torch.zeros_like(dress_tensor)[:, :17] # 17 dummy channels
83
  ], dim=1) # Total: 3+3+1+1+1+17 = 26 channels
84