levihsu commited on
Commit
04a4c7d
1 Parent(s): 6869802

Update run/gradio_ootd.py

Browse files
Files changed (1) hide show
  1. run/gradio_ootd.py +13 -3
run/gradio_ootd.py CHANGED
@@ -20,9 +20,9 @@ openpose_model_hd = OpenPose(0)
20
  parsing_model_hd = Parsing(0)
21
  ootd_model_hd = OOTDiffusionHD(0)
22
 
23
- openpose_model_dc = OpenPose(0)
24
- parsing_model_dc = Parsing(0)
25
- ootd_model_dc = OOTDiffusionDC(0)
26
 
27
 
28
  category_dict = ['upperbody', 'lowerbody', 'dress']
@@ -44,6 +44,11 @@ def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
44
  category = 0 # 0:upperbody; 1:lowerbody; 2:dress
45
 
46
  with torch.no_grad():
 
 
 
 
 
47
  garm_img = Image.open(garm_img).resize((768, 1024))
48
  vton_img = Image.open(vton_img).resize((768, 1024))
49
  keypoints = openpose_model_hd(vton_img.resize((384, 512)))
@@ -81,6 +86,11 @@ def process_dc(vton_img, garm_img, category, n_samples, n_steps, image_scale, se
81
  category =2
82
 
83
  with torch.no_grad():
 
 
 
 
 
84
  garm_img = Image.open(garm_img).resize((768, 1024))
85
  vton_img = Image.open(vton_img).resize((768, 1024))
86
  keypoints = openpose_model_dc(vton_img.resize((384, 512)))
 
20
  parsing_model_hd = Parsing(0)
21
  ootd_model_hd = OOTDiffusionHD(0)
22
 
23
+ openpose_model_dc = OpenPose(1)
24
+ parsing_model_dc = Parsing(1)
25
+ ootd_model_dc = OOTDiffusionDC(1)
26
 
27
 
28
  category_dict = ['upperbody', 'lowerbody', 'dress']
 
44
  category = 0 # 0:upperbody; 1:lowerbody; 2:dress
45
 
46
  with torch.no_grad():
47
+ openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
48
+ ootd_model_hd.pipe.to('cuda')
49
+ ootd_model_hd.image_encoder.to('cuda')
50
+ ootd_model_hd.text_encoder.to('cuda')
51
+
52
  garm_img = Image.open(garm_img).resize((768, 1024))
53
  vton_img = Image.open(vton_img).resize((768, 1024))
54
  keypoints = openpose_model_hd(vton_img.resize((384, 512)))
 
86
  category =2
87
 
88
  with torch.no_grad():
89
+ openpose_model_dc.preprocessor.body_estimation.model.to('cuda')
90
+ ootd_model_dc.pipe.to('cuda')
91
+ ootd_model_dc.image_encoder.to('cuda')
92
+ ootd_model_dc.text_encoder.to('cuda')
93
+
94
  garm_img = Image.open(garm_img).resize((768, 1024))
95
  vton_img = Image.open(vton_img).resize((768, 1024))
96
  keypoints = openpose_model_dc(vton_img.resize((384, 512)))