orhir commited on
Commit
83f87d6
·
verified ·
1 Parent(s): b6de01d

Update gradio_utils/utils.py

Browse files
Files changed (1) hide show
  1. gradio_utils/utils.py +22 -7
gradio_utils/utils.py CHANGED
@@ -110,6 +110,7 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
110
 
111
  @spaces.GPU(duration=30)
112
  def estimate(model, data):
 
113
  with torch.no_grad():
114
  return model(data)
115
 
@@ -166,17 +167,17 @@ def process(query_img, state,
166
  torch.tensor(target_weight_s).float()[None])
167
 
168
  data = {
169
- 'img_s': [support_img.to(device)],
170
- 'img_q': q_img.to(device),
171
- 'target_s': [target_s.to(device)],
172
- 'target_weight_s': [target_weight_s.to(device)],
173
  'target_q': None,
174
  'target_weight_q': None,
175
  'return_loss': False,
176
  'img_metas': [{'sample_skeleton': [state['skeleton']],
177
  'query_skeleton': state['skeleton'],
178
- 'sample_joints_3d': [kp_src_3d.to(device)],
179
- 'query_joints_3d': kp_src_3d.to(device),
180
  'sample_center': [kp_src_tensor.mean(dim=0)],
181
  'query_center': kp_src_tensor.mean(dim=0),
182
  'sample_scale': [
@@ -199,10 +200,24 @@ def process(query_img, state,
199
  if fp16_cfg is not None:
200
  wrap_fp16_model(model)
201
  load_checkpoint(model, checkpoint_path, map_location='cpu')
202
- model.eval().to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  str_data = json.dumps(data, cls=CustomEncoder)
205
  outputs = estimate(model, str_data)
 
206
  # visualize results
207
  vis_s_weight = target_weight_s[0]
208
  vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
 
110
 
111
  @spaces.GPU(duration=30)
112
  def estimate(model, data):
113
+ model.cuda()
114
  with torch.no_grad():
115
  return model(data)
116
 
 
167
  torch.tensor(target_weight_s).float()[None])
168
 
169
  data = {
170
+ 'img_s': [support_img],
171
+ 'img_q': q_img,
172
+ 'target_s': [target_s],
173
+ 'target_weight_s': [target_weight_s],
174
  'target_q': None,
175
  'target_weight_q': None,
176
  'return_loss': False,
177
  'img_metas': [{'sample_skeleton': [state['skeleton']],
178
  'query_skeleton': state['skeleton'],
179
+ 'sample_joints_3d': [kp_src_3d],
180
+ 'query_joints_3d': kp_src_3d,
181
  'sample_center': [kp_src_tensor.mean(dim=0)],
182
  'query_center': kp_src_tensor.mean(dim=0),
183
  'sample_scale': [
 
200
  if fp16_cfg is not None:
201
  wrap_fp16_model(model)
202
  load_checkpoint(model, checkpoint_path, map_location='cpu')
203
+
204
+ data["img_s"] = data["img_s"][0].cpu().numpy().tolist()
205
+ data["img_q"] = data["img_q"].cpu().numpy().tolist()
206
+ data['target_weight_s'][0] = data['target_weight_s'][0].cpu().numpy().tolist()
207
+ data['target_s'][0] = data['target_s'][0].cpu().numpy().tolist()
208
+
209
+ data['img_metas'][0]['sample_joints_3d'][0] = data['img_metas'][0]['sample_joints_3d'][0].cpu().tolist()
210
+ data['img_metas'][0]['query_joints_3d'] = data['img_metas'][0]['query_joints_3d'].cpu().tolist()
211
+ data['img_metas'][0]['sample_center'][0] = data['img_metas'][0]['sample_center'][0].cpu().tolist()
212
+ data['img_metas'][0]['query_center'] = data['img_metas'][0]['query_center'].cpu().tolist()
213
+ data['img_metas'][0]['sample_scale'][0] = data['img_metas'][0]['sample_scale'][0].cpu().tolist()
214
+ data['img_metas'][0]['query_scale'] = data['img_metas'][0]['query_scale'].cpu().tolist()
215
+
216
+ model.eval()
217
 
218
  str_data = json.dumps(data, cls=CustomEncoder)
219
  outputs = estimate(model, str_data)
220
+
221
  # visualize results
222
  vis_s_weight = target_weight_s[0]
223
  vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)