gaur3009 commited on
Commit
6989926
·
verified ·
1 Parent(s): 15e6502

Update warp_design_on_dress.py

Browse files
Files changed (1) hide show
  1. warp_design_on_dress.py +61 -72
warp_design_on_dress.py CHANGED
@@ -5,94 +5,83 @@ from torchvision import transforms
5
  from PIL import Image
6
  import numpy as np
7
  from networks import GMM, TOM, load_checkpoint, Options
8
- from preprocessing import pad_to_22_channels
9
 
10
- def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
11
- os.makedirs(output_dir, exist_ok=True)
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
-
14
- # Preprocessing
15
- im_h, im_w = 256, 192
16
- tf = transforms.Compose([
17
- transforms.Resize((im_h, im_w)),
18
  transforms.ToTensor(),
19
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
20
  ])
21
-
22
- # Load images
23
- dress_img = Image.open(dress_path).convert("RGB")
24
- design_img = Image.open(design_path).convert("RGB")
25
 
26
- # Convert to tensors
27
- dress_tensor = tf(dress_img).unsqueeze(0).to(device)
28
- design_tensor = tf(design_img).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Create design mask
31
- design_mask = torch.ones_like(design_tensor[:, :1, :, :]).to(device)
32
 
33
- # Prepare agnostic input
34
- agnostic = pad_to_22_channels(dress_tensor).to(device)
 
35
 
36
- # GMM Processing
 
 
 
 
 
 
 
 
37
  opt = Options()
38
  gmm = GMM(opt).to(device)
39
- load_checkpoint(gmm, gmm_ckpt, strict=False)
40
- gmm.eval()
41
-
 
 
 
 
 
 
 
 
 
42
  with torch.no_grad():
 
43
  grid, _ = gmm(agnostic, design_mask)
44
- warped_design = F.grid_sample(
45
- design_tensor,
46
- grid,
47
- padding_mode='border',
48
- align_corners=True
49
- )
50
- warped_mask = F.grid_sample(
51
- design_mask,
52
- grid,
53
- padding_mode='zeros',
54
- align_corners=True
55
- )
56
-
57
  # TOM Processing
58
- tom = TOM(opt).to(device)
59
- load_checkpoint(tom, tom_ckpt, strict=False)
60
- tom.eval()
61
-
62
  with torch.no_grad():
63
- # Create simplified feature inputs
64
- gray = dress_tensor.mean(dim=1, keepdim=True)
65
-
66
- # Create edge detection kernel
67
- kernel = torch.tensor(
68
- [[[-1, -1, -1],
69
- [-1, 8, -1],
70
- [-1, -1, -1]]], dtype=torch.float32, device=device)
71
-
72
- # Calculate edges
73
- edges = torch.abs(F.conv2d(gray, kernel, padding=1))
74
-
75
- # Combine inputs
76
- tom_input = torch.cat([
77
- dress_tensor, # 3 channels
78
- warped_design, # 3 channels
79
- warped_mask, # 1 channel
80
- gray, # 1 channel
81
- edges, # 1 channel
82
- torch.zeros_like(dress_tensor)[:, :17] # 17 dummy channels
83
- ], dim=1) # Total: 3+3+1+1+1+17 = 26 channels
84
-
85
- # Generate try-on result
86
  p_rendered, m_composite = tom(tom_input)
 
 
87
  tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
88
-
89
- # Convert to PIL image
90
- tryon = tryon.squeeze().detach().cpu()
91
- tryon = (tryon.permute(1, 2, 0).numpy() + 1) * 127.5
92
- tryon = np.clip(tryon, 0, 255).astype("uint8")
93
- out_pil = Image.fromarray(tryon)
94
 
95
  # Save output
96
- output_path = os.path.join(output_dir, "tryon.jpg")
97
- out_pil.save(output_path)
 
 
 
 
 
98
  return output_path
 
5
  from PIL import Image
6
  import numpy as np
7
  from networks import GMM, TOM, load_checkpoint, Options
8
+ import torchvision.transforms.functional as TF
9
 
10
+ def prepare_inputs(dress_path, design_path, height=256, width=192):
11
+ """Prepare and normalize input images"""
12
+ transform = transforms.Compose([
13
+ transforms.Resize((height, width)),
 
 
 
 
14
  transforms.ToTensor(),
15
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
16
  ])
 
 
 
 
17
 
18
+ dress_img = Image.open(dress_path).convert('RGB')
19
+ design_img = Image.open(design_path).convert('RGB')
20
+
21
+ dress_tensor = transform(dress_img).unsqueeze(0)
22
+ design_tensor = transform(design_img).unsqueeze(0)
23
+
24
+ # Create mask (assume design has transparent background)
25
+ design_arr = np.array(design_img)
26
+ if design_arr.shape[2] == 4: # Has alpha channel
27
+ mask = (design_arr[:, :, 3] > 0).astype(np.float32)
28
+ else:
29
+ mask = np.ones((design_arr.shape[0], design_arr.shape[1]), dtype=np.float32)
30
+
31
+ mask_img = Image.fromarray((mask * 255).astype(np.uint8))
32
+ mask_tensor = TF.to_tensor(TF.resize(mask_img, (height, width))).unsqueeze(0)
33
 
34
+ return dress_tensor, design_tensor, mask_tensor
 
35
 
36
+ def create_agnostic(dress_tensor):
37
+ """Create agnostic representation of dress"""
38
+ return dress_tensor.clone()
39
 
40
+ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
41
+ os.makedirs(output_dir, exist_ok=True)
42
+
43
+ # Prepare inputs
44
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
+ dress_tensor, design_tensor, design_mask = prepare_inputs(dress_path, design_path)
46
+ agnostic = create_agnostic(dress_tensor)
47
+
48
+ # Initialize models
49
  opt = Options()
50
  gmm = GMM(opt).to(device)
51
+ tom = TOM(opt).to(device)
52
+
53
+ # Load checkpoints
54
+ load_checkpoint(gmm, gmm_ckpt)
55
+ load_checkpoint(tom, tom_ckpt)
56
+
57
+ # Move tensors to device
58
+ agnostic = agnostic.to(device)
59
+ design_tensor = design_tensor.to(device)
60
+ design_mask = design_mask.to(device)
61
+
62
+ # GMM Processing
63
  with torch.no_grad():
64
+ gmm.eval()
65
  grid, _ = gmm(agnostic, design_mask)
66
+ warped_design = F.grid_sample(design_tensor, grid, padding_mode='border', align_corners=True)
67
+ warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros', align_corners=True)
68
+
 
 
 
 
 
 
 
 
 
 
69
  # TOM Processing
 
 
 
 
70
  with torch.no_grad():
71
+ tom.eval()
72
+ # Prepare TOM input: [agnostic, warped_design, warped_mask]
73
+ tom_input = torch.cat([agnostic, warped_design, warped_mask], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  p_rendered, m_composite = tom(tom_input)
75
+
76
+ # Final composition
77
  tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
 
 
 
 
 
 
78
 
79
  # Save output
80
+ tryon = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
81
+ tryon = (tryon * 0.5 + 0.5) * 255
82
+ tryon = tryon.clip(0, 255).astype(np.uint8)
83
+
84
+ output_path = os.path.join(output_dir, "warped_design.jpg")
85
+ Image.fromarray(tryon).save(output_path)
86
+
87
  return output_path