gaur3009 committed on
Commit
a4a6754
·
verified ·
1 Parent(s): 3ecccc5

Update warp_design_on_dress.py

Browse files
Files changed (1) hide show
  1. warp_design_on_dress.py +72 -33
warp_design_on_dress.py CHANGED
@@ -3,65 +3,104 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
- from networks import GMM, UnetGenerator, load_checkpoint, Options
7
  from preprocessing import pad_to_22_channels
8
 
9
def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    """Warp a design image onto a dress image and save the composite.

    The design is warped with the GMM network, then blended with the dress
    by the TOM U-Net. Everything runs on the CPU.

    Args:
        dress_path: Path of the dress image (opened with PIL, forced to RGB).
        design_path: Path of the design image (opened with PIL, forced to RGB).
        gmm_ckpt: Checkpoint path passed to ``load_checkpoint`` for the GMM.
        tom_ckpt: Checkpoint path passed to ``load_checkpoint`` for the TOM.
        output_dir: Directory for the result; created if missing.

    Returns:
        Path of the written ``tryon.jpg`` inside ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Preprocessing: resize to 256x192 and map pixel values to [-1, 1].
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dress_img = Image.open(dress_path).convert("RGB")
    design_img = Image.open(design_path).convert("RGB")

    dress_tensor = tf(dress_img).unsqueeze(0).cpu()
    design_tensor = tf(design_img).unsqueeze(0).cpu()
    # Full-white single-channel mask covering the whole design.
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # Fake agnostic: use the dress image itself.
    agnostic = dress_tensor.clone()

    # ----- GMM -----
    opt = Options()
    gmm = GMM(opt)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.cpu().eval()

    # The GMM is fed 22-channel inputs, so both tensors are padded first.
    agnostic_22ch = pad_to_22_channels(agnostic).contiguous()
    design_mask_22ch = pad_to_22_channels(design_mask).contiguous()

    with torch.no_grad():
        grid, _ = gmm(agnostic_22ch, design_mask_22ch)
        # align_corners is passed explicitly: the implicit default changed to
        # False in PyTorch 1.3 and emits a warning when left unspecified.
        # NOTE(review): True assumes the checkpoint was trained with the
        # pre-1.3 sampling convention -- confirm.
        warped_design = F.grid_sample(
            design_tensor, grid, padding_mode='border', align_corners=True
        )
        warped_mask = F.grid_sample(
            design_mask, grid, padding_mode='zeros', align_corners=True
        )

    # ----- TOM -----
    tom_in_channels = 26
    tom = UnetGenerator(tom_in_channels, 4, 6, ngf=64, norm_layer=torch.nn.InstanceNorm2d)
    load_checkpoint(tom, tom_ckpt)
    tom.cpu().eval()

    with torch.no_grad():
        tom_input = torch.cat([agnostic, warped_design, warped_mask], 1)
        # The U-Net is built for 26 input channels but only 3 + 3 + 1 = 7 are
        # available here; zero-fill the rest so the channel counts match.
        # NOTE(review): zeros are placeholders -- replace with the features
        # the checkpoint was trained on, if available.
        missing = tom_in_channels - tom_input.size(1)
        if missing > 0:
            filler = tom_input.new_zeros(
                tom_input.size(0), missing, tom_input.size(2), tom_input.size(3)
            )
            tom_input = torch.cat([tom_input, filler], 1)
        output = tom(tom_input)

        # First 3 channels are the rendered image, the last one the blend mask.
        p_rendered, m_composite = torch.split(output, [3, 1], dim=1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)

        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Save output: drop only the batch axis, then denormalize [-1, 1] -> [0, 255].
    out_img = (tryon.squeeze(0).permute(1, 2, 0).cpu().numpy() + 1) * 127.5
    out_img = out_img.clip(0, 255).astype("uint8")
    out_pil = Image.fromarray(out_img)

    output_path = os.path.join(output_dir, "tryon.jpg")
    out_pil.save(output_path)

    return output_path
 
 
 
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
+ from networks import GMM, TOM, load_checkpoint, Options # Updated imports
7
  from preprocessing import pad_to_22_channels
8
 
9
def _sobel_edges(gray):
    """Return the absolute horizontal and vertical Sobel responses of *gray*.

    ``padding=1`` keeps the 3x3 convolution output at the input's spatial
    size, so the edge maps can be concatenated with the other feature maps.
    """
    kernel_x = torch.tensor(
        [[[[1, 0, -1], [2, 0, -2], [1, 0, -1]]]],
        dtype=gray.dtype,
        device=gray.device,
    )
    kernel_y = torch.tensor(
        [[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]],
        dtype=gray.dtype,
        device=gray.device,
    )
    edges_x = torch.abs(F.conv2d(gray, kernel_x, padding=1))
    edges_y = torch.abs(F.conv2d(gray, kernel_y, padding=1))
    return edges_x, edges_y


def _build_tom_input(agnostic, warped_design, warped_mask, total_channels=26):
    """Concatenate the TOM feature maps and zero-fill up to *total_channels*."""
    gray = agnostic.mean(dim=1, keepdim=True)
    edges_x, edges_y = _sobel_edges(gray)

    # With RGB inputs: 3 (agnostic) + 3 (warped design) + 1 (warped mask)
    # + 1 (gray) + 1 (edges_x) + 1 (edges_y) = 10 real channels.
    features = torch.cat(
        [agnostic, warped_design, warped_mask, gray, edges_x, edges_y],
        dim=1,
    )

    # Dummy channels make up the difference (16 for RGB inputs).
    # NOTE(review): replace with real features if they become available.
    missing = total_channels - features.size(1)
    if missing > 0:
        filler = features.new_zeros(
            features.size(0), missing, features.size(2), features.size(3)
        )
        features = torch.cat([features, filler], dim=1)
    return features


def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    """Warp a design image onto a dress image and save the composite.

    The design is warped with the GMM network, then blended with the dress
    by the TOM network. Inference runs on CUDA when available, else on CPU.

    Args:
        dress_path: Path of the dress image (opened with PIL, forced to RGB).
        design_path: Path of the design image (opened with PIL, forced to RGB).
        gmm_ckpt: Checkpoint path passed to ``load_checkpoint`` for the GMM.
        tom_ckpt: Checkpoint path passed to ``load_checkpoint`` for the TOM.
        output_dir: Directory for the result; created if missing.

    Returns:
        Path of the written ``tryon.jpg`` inside ``output_dir``.

    Raises:
        ValueError: If an input image cannot be loaded or the output image
            cannot be saved; the underlying error is chained as the cause.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Preprocessing: resize to 256x192 and map pixel values to [-1, 1].
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load both images; callers rely on ValueError for any load failure.
    try:
        dress_img = Image.open(dress_path).convert("RGB")
        design_img = Image.open(design_path).convert("RGB")
    except Exception as e:
        raise ValueError(f"Error loading images: {str(e)}") from e

    # Convert to batched tensors (still on the CPU at this point).
    dress_tensor = tf(dress_img).unsqueeze(0).cpu()
    design_tensor = tf(design_img).unsqueeze(0).cpu()
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # The dress image itself stands in for the agnostic representation.
    agnostic = dress_tensor.clone()

    opt = Options()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # GMM: move to the target device after loading, so the final placement
    # does not depend on what load_checkpoint does internally.
    gmm = GMM(opt)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.to(device).eval()

    # Pad on the CPU (as before), then move every tensor that is used again
    # to the same device. `agnostic` used to stay on the CPU, which broke the
    # later concatenation whenever CUDA was selected.
    agnostic_22ch = pad_to_22_channels(agnostic).contiguous().to(device)
    design_mask_22ch = pad_to_22_channels(design_mask).contiguous().to(device)
    design_tensor = design_tensor.to(device)
    design_mask = design_mask.to(device)
    agnostic = agnostic.to(device)

    with torch.no_grad():
        grid, _ = gmm(agnostic_22ch, design_mask_22ch)
        warped_design = F.grid_sample(
            design_tensor,
            grid,
            padding_mode='border',
            align_corners=True,
        )
        warped_mask = F.grid_sample(
            design_mask,
            grid,
            padding_mode='zeros',
            align_corners=True,
        )

    # TOM
    tom = TOM(opt)
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.to(device).eval()

    with torch.no_grad():
        # 26-channel input: 10 real feature channels + 16 zero channels.
        tom_input = _build_tom_input(agnostic, warped_design, warped_mask)

        p_rendered, m_composite = tom(tom_input)
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Denormalize [-1, 1] -> [0, 255]; index the batch axis explicitly rather
    # than squeezing every size-1 dimension.
    tryon = tryon.clamp(-1, 1)
    out_img = (tryon[0].permute(1, 2, 0).cpu().numpy() + 1) * 127.5
    out_img = out_img.clip(0, 255).astype("uint8")

    output_path = os.path.join(output_dir, "tryon.jpg")
    try:
        Image.fromarray(out_img).save(output_path)
    except Exception as e:
        raise ValueError(f"Error saving output image: {str(e)}") from e
    return output_path