orhir commited on
Commit
b6de01d
·
verified ·
1 Parent(s): 8fda4d0

Update gradio_utils/utils.py

Browse files
Files changed (1) hide show
  1. gradio_utils/utils.py +19 -3
gradio_utils/utils.py CHANGED
@@ -106,7 +106,22 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
106
  plt.axis('off') # command for hiding the axis.
107
  return plt
108
 
109
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def process(query_img, state,
111
  cfg_path='configs/test/1shot_split1.py',
112
  checkpoint_path='ckpt/1shot_split1.pth'):
@@ -185,8 +200,9 @@ def process(query_img, state,
185
  wrap_fp16_model(model)
186
  load_checkpoint(model, checkpoint_path, map_location='cpu')
187
  model.eval().to(device)
188
- with torch.no_grad():
189
- outputs = model(**data)
 
190
  # visualize results
191
  vis_s_weight = target_weight_s[0]
192
  vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
 
106
  plt.axis('off') # command for hiding the axis.
107
  return plt
108
 
109
+
110
+
111
+ @spaces.GPU(duration=30)
112
+ def estimate(model, data):
113
+ with torch.no_grad():
114
+ return model(data)
115
+
116
+
117
+ # Custom JSON encoder to handle non-serializable objects
118
+ class CustomEncoder(json.JSONEncoder):
119
+ def default(self, obj):
120
+ if isinstance(obj, np.ndarray):
121
+ return obj.tolist()
122
+ return super().default(obj)
123
+
124
+
125
  def process(query_img, state,
126
  cfg_path='configs/test/1shot_split1.py',
127
  checkpoint_path='ckpt/1shot_split1.pth'):
 
200
  wrap_fp16_model(model)
201
  load_checkpoint(model, checkpoint_path, map_location='cpu')
202
  model.eval().to(device)
203
+
204
+ str_data = json.dumps(data, cls=CustomEncoder)
205
+ outputs = estimate(model, str_data)
206
  # visualize results
207
  vis_s_weight = target_weight_s[0]
208
  vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)