Hugging Face Spaces — Space status: Running on ZeroGPU.
Commit: "Update gradio_utils/utils.py" (+19 −3).
Changed file: gradio_utils/utils.py
@@ -106,7 +106,22 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
|
|
106 |
plt.axis('off') # command for hiding the axis.
|
107 |
return plt
|
108 |
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
def process(query_img, state,
|
111 |
cfg_path='configs/test/1shot_split1.py',
|
112 |
checkpoint_path='ckpt/1shot_split1.pth'):
|
@@ -185,8 +200,9 @@ def process(query_img, state,
|
|
185 |
wrap_fp16_model(model)
|
186 |
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
187 |
model.eval().to(device)
|
188 |
-
|
189 |
-
|
|
|
190 |
# visualize results
|
191 |
vis_s_weight = target_weight_s[0]
|
192 |
vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
|
|
|
106 |
plt.axis('off') # command for hiding the axis.
|
107 |
return plt
|
108 |
|
109 |
+
|
110 |
+
|
111 |
+
@spaces.GPU(duration=30)
def estimate(model, data):
    """Run a single forward pass of *model* on *data* with gradients disabled.

    Decorated with ``@spaces.GPU`` so the call is scheduled on ZeroGPU
    hardware for at most 30 seconds.

    Args:
        model: A callable (e.g. an ``nn.Module``) invoked as ``model(data)``.
        data: The input forwarded unchanged to the model.

    Returns:
        Whatever ``model(data)`` returns.
    """
    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = model(data)
    return outputs
|
115 |
+
|
116 |
+
|
117 |
+
# Custom JSON encoder to handle non-serializable objects
class CustomEncoder(json.JSONEncoder):
    """JSON encoder that knows how to serialize numpy objects.

    ``np.ndarray`` values are converted to (nested) Python lists, and
    numpy scalars (``np.int64``, ``np.float32``, ``np.bool_``, ...) are
    converted to the corresponding native Python value.  Anything else
    falls through to :meth:`json.JSONEncoder.default`, which raises
    ``TypeError`` as usual.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            # tolist() yields nested lists of native Python numbers.
            return obj.tolist()
        if isinstance(obj, np.generic):
            # numpy scalar -> native Python int/float/bool; previously these
            # raised TypeError even though arrays were handled.
            return obj.item()
        return super().default(obj)
|
123 |
+
|
124 |
+
|
125 |
def process(query_img, state,
|
126 |
cfg_path='configs/test/1shot_split1.py',
|
127 |
checkpoint_path='ckpt/1shot_split1.pth'):
|
|
|
200 |
wrap_fp16_model(model)
|
201 |
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
202 |
model.eval().to(device)
|
203 |
+
|
204 |
+
str_data = json.dumps(data, cls=CustomEncoder)
|
205 |
+
outputs = estimate(model, str_data)
|
206 |
# visualize results
|
207 |
vis_s_weight = target_weight_s[0]
|
208 |
vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
|