gaur3009 commited on
Commit
08cfba0
·
verified ·
1 Parent(s): 4560e58

Update warp_design_on_dress.py

Browse files
Files changed (1) hide show
  1. warp_design_on_dress.py +1 -1
warp_design_on_dress.py CHANGED
@@ -39,7 +39,7 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
39
  # ----- TOM -----
40
  tom = UnetGenerator(26, 4, 6, ngf=64, norm_layer=torch.nn.InstanceNorm2d)
41
  load_checkpoint(tom, tom_ckpt)
42
- tom.cuda().eval()
43
 
44
  with torch.no_grad():
45
  tom_input = torch.cat([agnostic, warped_design, warped_mask], 1)
 
39
  # ----- TOM -----
40
  tom = UnetGenerator(26, 4, 6, ngf=64, norm_layer=torch.nn.InstanceNorm2d)
41
  load_checkpoint(tom, tom_ckpt)
42
+ tom.cpu().eval()
43
 
44
  with torch.no_grad():
45
  tom_input = torch.cat([agnostic, warped_design, warped_mask], 1)