gaur3009 commited on
Commit
b61f3f8
·
verified ·
1 Parent(s): f9751b6

Create warp_design_on_dress.py

Browse files
Files changed (1) hide show
  1. warp_design_on_dress.py +62 -0
warp_design_on_dress.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from networks import GMM, UnetGenerator, load_checkpoint
7
+
8
+
9
+ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
10
+ os.makedirs(output_dir, exist_ok=True)
11
+
12
+ # Preprocessing
13
+ im_h, im_w = 256, 192
14
+ tf = transforms.Compose([
15
+ transforms.Resize((im_h, im_w)),
16
+ transforms.ToTensor()
17
+ ])
18
+
19
+ dress_img = Image.open(dress_path).convert("RGB")
20
+ design_img = Image.open(design_path).convert("RGB")
21
+
22
+ dress_tensor = tf(dress_img).unsqueeze(0).cuda()
23
+ design_tensor = tf(design_img).unsqueeze(0).cuda()
24
+ design_mask = torch.ones_like(design_tensor[:, :1, :, :]) # full white mask
25
+
26
+ # Fake agnostic: use the dress image itself
27
+ agnostic = dress_tensor.clone()
28
+
29
+ # ----- GMM -----
30
+ gmm = GMM(opt=None)
31
+ load_checkpoint(gmm, gmm_ckpt)
32
+ gmm.cuda().eval()
33
+
34
+ with torch.no_grad():
35
+ grid, _ = gmm(agnostic, design_mask)
36
+ warped_design = F.grid_sample(design_tensor, grid, padding_mode='border')
37
+ warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros')
38
+
39
+ # ----- TOM -----
40
+ tom = UnetGenerator(26, 4, 6, ngf=64, norm_layer=torch.nn.InstanceNorm2d)
41
+ load_checkpoint(tom, tom_ckpt)
42
+ tom.cuda().eval()
43
+
44
+ with torch.no_grad():
45
+ tom_input = torch.cat([agnostic, warped_design, warped_mask], 1)
46
+ output = tom(tom_input)
47
+
48
+ p_rendered, m_composite = torch.split(output, 3, 1)
49
+ p_rendered = torch.tanh(p_rendered)
50
+ m_composite = torch.sigmoid(m_composite)
51
+
52
+ tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
53
+
54
+ # Save output
55
+ out_img = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
56
+ out_img = (out_img * 255).astype("uint8")
57
+ out_pil = Image.fromarray(out_img)
58
+
59
+ output_path = os.path.join(output_dir, "tryon.jpg")
60
+ out_pil.save(output_path)
61
+
62
+ return output_path