orhir committed
Commit 3c43b62 · verified · Parent: 71d8415

Update gradio_utils/utils.py

Files changed (1)
  1. gradio_utils/utils.py +4 -7
gradio_utils/utils.py CHANGED
@@ -44,7 +44,7 @@ def adj_mx_from_edges(num_pts, skeleton, device='cpu', normalization_fix=True):
     adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
     return adj

-def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
+def plot_results(support_img, query_img, query_w,
                  skeleton=None, prediction=None, radius=6, in_color=None,
                  original_skeleton=None, img_alpha=0.6, target_keypoints=None):
     h, w, c = support_img.shape
@@ -52,7 +52,7 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
     if isinstance(prediction, torch.Tensor):
         prediction = prediction.numpy()
     if isinstance(original_skeleton, list):
-        original_skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton]).numpy()[0]
+        original_skeleton = adj_mx_from_edges(num_pts=prediction.shape[0], skeleton=[original_skeleton]).numpy()[0]
     query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
     img = query_img
     w = query_w
@@ -204,12 +204,9 @@ def process(query_img, state,
     support_kp = kp_src_3d
     out = plot_results(vis_s_image,
                        vis_q_image,
-                       support_kp,
                        vis_s_weight,
-                       None,
-                       vis_s_weight,
-                       outputs['skeleton'][1],
-                       torch.tensor(outputs['points']).squeeze().cpu(),
+                       skeleton=outputs['skeleton'][1],
+                       prediction=torch.tensor(outputs['points']).squeeze().cpu(),
                        original_skeleton=state['skeleton'],
                        img_alpha=1.0,
                        )
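For context on what the second hunk fixes: adj_mx_from_edges previously received a hard-coded num_pts=100 and the (possibly edited) skeleton variable, so the adjacency matrix it returned was not guaranteed to line up with the keypoints actually being drawn. The sketch below is not the repository's adj_mx_from_edges; toy_adj_from_edges, the 17-point prediction, and the three-edge skeleton are hypothetical stand-ins used only to show why the adjacency must be sized from prediction.shape[0] and built from the original skeleton.

import numpy as np
import torch

# Minimal illustrative helper, NOT the repo's adj_mx_from_edges:
# build a symmetric 0/1 adjacency matrix from an edge list.
def toy_adj_from_edges(num_pts, edges):
    adj = np.zeros((num_pts, num_pts), dtype=np.float32)
    for a, b in edges:
        adj[a, b] = 1.0
        adj[b, a] = 1.0
    return adj

# Hypothetical inputs standing in for outputs['points'] and state['skeleton'].
prediction = torch.rand(17, 2)                # 17 predicted keypoints
original_skeleton = [(0, 1), (1, 2), (2, 3)]  # limb edges between keypoint indices

# Old behaviour: num_pts was hard-coded to 100, so the adjacency did not match
# the 17 predicted keypoints. New behaviour: size it from the prediction itself.
adj = toy_adj_from_edges(num_pts=prediction.shape[0], edges=original_skeleton)
assert adj.shape == (prediction.shape[0], prediction.shape[0])

The third hunk hardens the same call site in a different way: after support_kp, support_w, and query_kp were dropped from the plot_results signature, the remaining skeleton and prediction arguments are passed by keyword so they cannot silently bind to the wrong positional parameters.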