gaur3009 commited on
Commit
198f320
·
verified ·
1 Parent(s): 0316ce2

Update warp_design_on_dress.py

Browse files
Files changed (1) hide show
  1. warp_design_on_dress.py +46 -58
warp_design_on_dress.py CHANGED
@@ -3,13 +3,15 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
- from networks import GMM, TOM, load_checkpoint, Options # Updated imports
 
7
  from preprocessing import pad_to_22_channels
8
 
9
  def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
10
  os.makedirs(output_dir, exist_ok=True)
11
-
12
- # Preprocessing with enhanced normalization
 
13
  im_h, im_w = 256, 192
14
  tf = transforms.Compose([
15
  transforms.Resize((im_h, im_w)),
@@ -17,39 +19,28 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
17
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
18
  ])
19
 
20
- # Load and prepare images with error handling
21
- try:
22
- dress_img = Image.open(dress_path).convert("RGB")
23
- design_img = Image.open(design_path).convert("RGB")
24
- except Exception as e:
25
- raise ValueError(f"Error loading images: {str(e)}")
26
-
27
- # Convert to tensors
28
- dress_tensor = tf(dress_img).unsqueeze(0).cpu()
29
- design_tensor = tf(design_img).unsqueeze(0).cpu()
30
- design_mask = torch.ones_like(design_tensor[:, :1, :, :])
31
-
32
- # Prepare agnostic (dress image)
33
- agnostic = dress_tensor.clone()
34
 
35
- # Initialize models with proper device handling
36
- opt = Options()
37
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38
 
 
 
 
 
 
 
39
  # GMM Processing
 
40
  gmm = GMM(opt).to(device)
41
  load_checkpoint(gmm, gmm_ckpt, strict=False)
42
  gmm.eval()
43
 
44
- # Convert to required channels and move to device
45
- agnostic_22ch = pad_to_22_channels(agnostic).contiguous().to(device)
46
- design_mask_22ch = pad_to_22_channels(design_mask).contiguous().to(device)
47
- design_tensor = design_tensor.to(device)
48
- design_mask = design_mask.to(device)
49
-
50
  with torch.no_grad():
51
- # Process through GMM with align_corners
52
- grid, _ = gmm(agnostic_22ch, design_mask_22ch)
53
  warped_design = F.grid_sample(
54
  design_tensor,
55
  grid,
@@ -64,43 +55,40 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
64
  )
65
 
66
  # TOM Processing
67
- tom = TOM(opt).to(device) # Using the new TOM class
68
  load_checkpoint(tom, tom_ckpt, strict=False)
69
  tom.eval()
70
 
71
  with torch.no_grad():
72
- # Prepare proper 26-channel input
73
- # Generate additional features (replace with actual feature extraction if available)
74
- gray = agnostic.mean(dim=1, keepdim=True)
75
- edges_x = torch.abs(F.conv2d(gray,
76
- torch.tensor([[[[1,0,-1],[2,0,-2],[1,0,-1]]]], device=device).float()))
77
- edges_y = torch.abs(F.conv2d(gray,
78
- torch.tensor([[[[1,2,1],[0,0,0],[-1,-2,-1]]]], device=device).float()))
79
 
80
- # Combine all features (3+3+1+19=26)
81
  tom_input = torch.cat([
82
- agnostic, # 3 channels
83
- warped_design, # 3 channels
84
- warped_mask, # 1 channel
85
- gray, # 1 channel
86
- edges_x, # 1 channel
87
- edges_y, # 1 channel
88
- torch.zeros_like(agnostic)[:, :16] # 16 dummy channels (replace with real features)
89
- ], dim=1)
90
 
91
- # Process through TOM
92
  p_rendered, m_composite = tom(tom_input)
93
  tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
94
 
95
- # Save output with proper denormalization
96
- tryon = tryon.clamp(-1, 1) # Ensure valid range
97
- out_img = (tryon.squeeze().permute(1, 2, 0).cpu().numpy() + 1) * 127.5
98
- out_img = out_img.clip(0, 255).astype("uint8")
99
-
100
- try:
101
- out_pil = Image.fromarray(out_img)
102
- output_path = os.path.join(output_dir, "tryon.jpg")
103
- out_pil.save(output_path)
104
- return output_path
105
- except Exception as e:
106
- raise ValueError(f"Error saving output image: {str(e)}")
 
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
+ import numpy as np
7
+ from networks import GMM, TOM, load_checkpoint, Options
8
  from preprocessing import pad_to_22_channels
9
 
10
  def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
11
  os.makedirs(output_dir, exist_ok=True)
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ # Preprocessing
15
  im_h, im_w = 256, 192
16
  tf = transforms.Compose([
17
  transforms.Resize((im_h, im_w)),
 
19
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
20
  ])
21
 
22
+ # Load images
23
+ dress_img = Image.open(dress_path).convert("RGB")
24
+ design_img = Image.open(design_path).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Convert to tensors
27
+ dress_tensor = tf(dress_img).unsqueeze(0).to(device)
28
+ design_tensor = tf(design_img).unsqueeze(0).to(device)
29
 
30
+ # Create design mask (assuming design covers entire area)
31
+ design_mask = torch.ones_like(design_tensor[:, :1, :, :]).to(device)
32
+
33
+ # Prepare agnostic input (dress image)
34
+ agnostic = pad_to_22_channels(dress_tensor).to(device)
35
+
36
  # GMM Processing
37
+ opt = Options()
38
  gmm = GMM(opt).to(device)
39
  load_checkpoint(gmm, gmm_ckpt, strict=False)
40
  gmm.eval()
41
 
 
 
 
 
 
 
42
  with torch.no_grad():
43
+ grid, _ = gmm(agnostic, design_mask)
 
44
  warped_design = F.grid_sample(
45
  design_tensor,
46
  grid,
 
55
  )
56
 
57
  # TOM Processing
58
+ tom = TOM(opt).to(device)
59
  load_checkpoint(tom, tom_ckpt, strict=False)
60
  tom.eval()
61
 
62
  with torch.no_grad():
63
+ # Create simplified feature inputs
64
+ gray = dress_tensor.mean(dim=1, keepdim=True)
65
+ edges = torch.abs(F.conv2d(
66
+ gray,
67
+ torch.tensor([[[[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]], dtype=torch.float32, device=device),
68
+ padding=1
69
+ ))
70
 
71
+ # Combine inputs (agnostic + warped design + warped mask + features)
72
  tom_input = torch.cat([
73
+ dress_tensor, # 3 channels
74
+ warped_design, # 3 channels
75
+ warped_mask, # 1 channel
76
+ gray, # 1 channel
77
+ edges, # 1 channel
78
+ torch.zeros_like(dress_tensor)[:, :17] # 17 dummy channels
79
+ ], dim=1) # Total: 3+3+1+1+1+17 = 26 channels
 
80
 
81
+ # Generate try-on result
82
  p_rendered, m_composite = tom(tom_input)
83
  tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
84
 
85
+ # Convert to PIL image
86
+ tryon = tryon.squeeze().detach().cpu()
87
+ tryon = (tryon.permute(1, 2, 0).numpy() + 1) * 127.5
88
+ tryon = np.clip(tryon, 0, 255).astype("uint8")
89
+ out_pil = Image.fromarray(tryon)
90
+
91
+ # Save output
92
+ output_path = os.path.join(output_dir, "tryon.jpg")
93
+ out_pil.save(output_path)
94
+ return output_path