orhir commited on
Commit
12fc9ad
·
verified ·
1 Parent(s): 9cbe26b

Update gradio_utils/utils.py

Browse files
Files changed (1) hide show
  1. gradio_utils/utils.py +5 -4
gradio_utils/utils.py CHANGED
@@ -37,6 +37,7 @@ def adj_mx_from_edges(num_pts, skeleton, device='cuda', normalization_fix=True):
37
  return adj
38
 
39
 
 
40
  def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
41
  skeleton=None, prediction=None, radius=6, in_color=None,
42
  original_skeleton=None, img_alpha=0.6, target_keypoints=None):
@@ -126,13 +127,13 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
126
  plt.subplots_adjust(0, 0, 1, 1, 0, 0)
127
  return plt
128
 
129
-
130
  def process(query_img, state,
131
  cfg_path='configs/test/1shot_split1.py',
132
  checkpoint_path='ckpt/1shot_split1.pth'):
133
  cfg = Config.fromfile(cfg_path)
134
- width, height, _ = state['original_support_image'].shape
135
- kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
136
  kp_src_np[:, 0] = kp_src_np[:, 0] / (width // 4) * cfg.model.encoder_config.img_size
137
  kp_src_np[:, 1] = kp_src_np[:, 1] / (height // 4) * cfg.model.encoder_config.img_size
138
  kp_src_np = np.flip(kp_src_np, 1).copy()
@@ -226,7 +227,7 @@ def process(query_img, state,
226
 
227
  def update_examples(support_img, posed_support, query_img, state, r=0.015, width=0.02):
228
  state['color_idx'] = 0
229
- state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
230
  support_img, posed_support, _ = set_query(support_img, state, example=True)
231
  w, h = support_img.size
232
  draw_pose = ImageDraw.Draw(support_img)
 
37
  return adj
38
 
39
 
40
+ @spaces.GPU
41
  def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_w,
42
  skeleton=None, prediction=None, radius=6, in_color=None,
43
  original_skeleton=None, img_alpha=0.6, target_keypoints=None):
 
127
  plt.subplots_adjust(0, 0, 1, 1, 0, 0)
128
  return plt
129
 
130
+ @spaces.GPU
131
  def process(query_img, state,
132
  cfg_path='configs/test/1shot_split1.py',
133
  checkpoint_path='ckpt/1shot_split1.pth'):
134
  cfg = Config.fromfile(cfg_path)
135
+ width, height, _ = state['images']['image_orig'].shape
136
+ kp_src_np = np.array(state['points']).copy().astype(np.float32)
137
  kp_src_np[:, 0] = kp_src_np[:, 0] / (width // 4) * cfg.model.encoder_config.img_size
138
  kp_src_np[:, 1] = kp_src_np[:, 1] / (height // 4) * cfg.model.encoder_config.img_size
139
  kp_src_np = np.flip(kp_src_np, 1).copy()
 
227
 
228
  def update_examples(support_img, posed_support, query_img, state, r=0.015, width=0.02):
229
  state['color_idx'] = 0
230
+ state['images']['image_orig'] = np.array(support_img)[:, :, ::-1].copy()
231
  support_img, posed_support, _ = set_query(support_img, state, example=True)
232
  w, h = support_img.size
233
  draw_pose = ImageDraw.Draw(support_img)