Spaces:
Runtime error
Runtime error
bthndmn12
commited on
Commit
•
87757c1
1
Parent(s):
06a19d8
Fixed some bugs
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ model = SamModel.from_pretrained("./checkpoint",local_files_only=True)
|
|
17 |
def get_bbox(gt_map):
|
18 |
|
19 |
if gt_map.ndim > 2:
|
20 |
-
gt_map = gt_map[:, :, 0]
|
21 |
|
22 |
# Check if the ground truth map is empty
|
23 |
if np.sum(gt_map) == 0:
|
@@ -57,7 +57,7 @@ def process_image(image_input):
|
|
57 |
outputs = model(**inputs, multimask_output=False)
|
58 |
|
59 |
# Process model output
|
60 |
-
seg_prob = torch.sigmoid(outputs
|
61 |
seg_prob = seg_prob.cpu().numpy().squeeze()
|
62 |
seg = (seg_prob > 0.5).astype(np.uint8)
|
63 |
|
|
|
17 |
def get_bbox(gt_map):
|
18 |
|
19 |
if gt_map.ndim > 2:
|
20 |
+
gt_map = gt_map[:, :, 0]
|
21 |
|
22 |
# Check if the ground truth map is empty
|
23 |
if np.sum(gt_map) == 0:
|
|
|
57 |
outputs = model(**inputs, multimask_output=False)
|
58 |
|
59 |
# Process model output
|
60 |
+
seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
61 |
seg_prob = seg_prob.cpu().numpy().squeeze()
|
62 |
seg = (seg_prob > 0.5).astype(np.uint8)
|
63 |
|