orhir committed on
Commit
8e246b4
·
verified ·
1 Parent(s): e6f223b

Update gradio_utils/utils.py

Browse files
Files changed (1) hide show
  1. gradio_utils/utils.py +6 -9
gradio_utils/utils.py CHANGED
@@ -30,7 +30,7 @@ def process_img(support_image, global_state):
30
  return support_image, global_state
31
 
32
 
33
- def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
34
  adj_mx = torch.empty(0, device=device)
35
  batch_size = len(skeleton)
36
  for b in range(batch_size):
@@ -43,16 +43,15 @@ def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
43
  adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
44
  return adj
45
 
46
- @spaces.GPU(duration=30)
47
  def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
48
  skeleton=None, prediction=None, radius=6, in_color=None,
49
  original_skeleton=None, img_alpha=0.6, target_keypoints=None):
50
  h, w, c = support_img.shape
51
  prediction = prediction[-1] * h
52
  if isinstance(prediction, torch.Tensor):
53
- prediction = prediction.cpu().numpy()
54
  if isinstance(original_skeleton, list):
55
- original_skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton]).cpu().numpy()[0]
56
  query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
57
  img = query_img
58
  w = query_w
@@ -103,8 +102,8 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
103
  zorder=1)
104
  axes.add_artist(patch)
105
 
106
- plt.axis('off') # command for hiding the axis.
107
- return plt
108
 
109
 
110
 
@@ -200,7 +199,6 @@ def process(query_img, state,
200
  vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
201
  vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
202
  support_kp = kp_src_3d
203
- print("Try to Plot")
204
  out = plot_results(vis_s_image,
205
  vis_q_image,
206
  support_kp,
@@ -208,11 +206,10 @@ def process(query_img, state,
208
  None,
209
  vis_s_weight,
210
  outputs['skeleton'][1],
211
- torch.tensor(outputs['points']).squeeze(),
212
  original_skeleton=state['skeleton'],
213
  img_alpha=1.0,
214
  )
215
- print("Plot Success!")
216
  return out
217
 
218
 
 
30
  return support_image, global_state
31
 
32
 
33
+ def adj_mx_from_edges(num_pts, skeleton, device='cpu', normalization_fix=True):
34
  adj_mx = torch.empty(0, device=device)
35
  batch_size = len(skeleton)
36
  for b in range(batch_size):
 
43
  adj = adj_mx + trans_adj_mx * cond - adj_mx * cond
44
  return adj
45
 
 
46
  def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
47
  skeleton=None, prediction=None, radius=6, in_color=None,
48
  original_skeleton=None, img_alpha=0.6, target_keypoints=None):
49
  h, w, c = support_img.shape
50
  prediction = prediction[-1] * h
51
  if isinstance(prediction, torch.Tensor):
52
+ prediction = prediction.numpy()
53
  if isinstance(original_skeleton, list):
54
+ original_skeleton = adj_mx_from_edges(num_pts=100, skeleton=[skeleton]).numpy()[0]
55
  query_img = (query_img - np.min(query_img)) / (np.max(query_img) - np.min(query_img))
56
  img = query_img
57
  w = query_w
 
102
  zorder=1)
103
  axes.add_artist(patch)
104
 
105
+ plt.axis('off') # command for hiding the axis.
106
+ return plt
107
 
108
 
109
 
 
199
  vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
200
  vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
201
  support_kp = kp_src_3d
 
202
  out = plot_results(vis_s_image,
203
  vis_q_image,
204
  support_kp,
 
206
  None,
207
  vis_s_weight,
208
  outputs['skeleton'][1],
209
+ torch.tensor(outputs['points']).squeeze().cpu(),
210
  original_skeleton=state['skeleton'],
211
  img_alpha=1.0,
212
  )
 
213
  return out
214
 
215