Spaces:
Build error
Build error
Commit
·
2751f79
1
Parent(s):
a0ea772
fixed noise in mask issue
Browse files
- app/main.py +3 -1
- app/per_sam/model.py +5 -1
- app/sam/predictor.py +2 -0
- instance-labeler/app/canvas.tsx +3 -3
- requirements.txt +1 -0
app/main.py
CHANGED
@@ -208,7 +208,9 @@ async def label_image(image: str, mask_labels: MaskLabels) -> Response:
|
|
208 |
for i in range(len(mask_labels.masks)):
|
209 |
mask_i = mask_labels.masks[i]
|
210 |
mask_i = mask_i.replace("data:image/png;base64,", "")
|
211 |
-
|
|
|
|
|
212 |
bboxes = (
|
213 |
batched_mask_to_box(
|
214 |
torch.as_tensor(np.array([np.array(m) for m in save_masks]))
|
|
|
208 |
for i in range(len(mask_labels.masks)):
|
209 |
mask_i = mask_labels.masks[i]
|
210 |
mask_i = mask_i.replace("data:image/png;base64,", "")
|
211 |
+
mask_i = Image.open(io.BytesIO(base64.b64decode(mask_i))).convert("L")
|
212 |
+
mask_i = mask_i.point(lambda p: 0 if p <= 1 else p)
|
213 |
+
save_masks.append(mask_i)
|
214 |
bboxes = (
|
215 |
batched_mask_to_box(
|
216 |
torch.as_tensor(np.array([np.array(m) for m in save_masks]))
|
app/per_sam/model.py
CHANGED
@@ -8,6 +8,7 @@ import cv2
|
|
8 |
from torchvision.ops.boxes import batched_nms
|
9 |
from app.mobile_sam import SamPredictor
|
10 |
from app.mobile_sam.utils import batched_mask_to_box
|
|
|
11 |
|
12 |
|
13 |
def point_selection(mask_sim, topk: int = 1):
|
@@ -139,7 +140,10 @@ def fast_inference(
|
|
139 |
# Weighted sum three-scale masks
|
140 |
logits_high = logits_high * weights.unsqueeze(-1)
|
141 |
logit_high = logits_high.sum(0)
|
142 |
-
mask = (logit_high > 0).detach().cpu().numpy()
|
|
|
|
|
|
|
143 |
|
144 |
logits = logits * weights_np[..., None]
|
145 |
logit = logits.sum(0)
|
|
|
8 |
from torchvision.ops.boxes import batched_nms
|
9 |
from app.mobile_sam import SamPredictor
|
10 |
from app.mobile_sam.utils import batched_mask_to_box
|
11 |
+
from app.sam.postprocessing import clean_mask_torch
|
12 |
|
13 |
|
14 |
def point_selection(mask_sim, topk: int = 1):
|
|
|
140 |
# Weighted sum three-scale masks
|
141 |
logits_high = logits_high * weights.unsqueeze(-1)
|
142 |
logit_high = logits_high.sum(0)
|
143 |
+
# mask = (logit_high > 0).detach().cpu().numpy()
|
144 |
+
|
145 |
+
mask = (logit_high > 0)
|
146 |
+
mask = clean_mask_torch(mask).bool()[0, 0, :, :].detach().cpu().numpy()
|
147 |
|
148 |
logits = logits * weights_np[..., None]
|
149 |
logit = logits.sum(0)
|
app/sam/predictor.py
CHANGED
@@ -10,6 +10,7 @@ import torch
|
|
10 |
from typing import Optional, Tuple
|
11 |
from .sam import Sam
|
12 |
from .transforms import ResizeLongestSide
|
|
|
13 |
|
14 |
|
15 |
class SamPredictor:
|
@@ -237,6 +238,7 @@ class SamPredictor:
|
|
237 |
|
238 |
if not return_logits:
|
239 |
masks = masks > self.model.mask_threshold
|
|
|
240 |
|
241 |
return masks, iou_predictions, low_res_masks
|
242 |
|
|
|
10 |
from typing import Optional, Tuple
|
11 |
from .sam import Sam
|
12 |
from .transforms import ResizeLongestSide
|
13 |
+
from .postprocess import clean_mask_torch
|
14 |
|
15 |
|
16 |
class SamPredictor:
|
|
|
238 |
|
239 |
if not return_logits:
|
240 |
masks = masks > self.model.mask_threshold
|
241 |
+
masks = clean_mask_torch(masks.int()).bool()
|
242 |
|
243 |
return masks, iou_predictions, low_res_masks
|
244 |
|
instance-labeler/app/canvas.tsx
CHANGED
@@ -28,7 +28,7 @@ const maskFilter = (imageData: ImageData, color: RGB) => {
|
|
28 |
const g = imageData.data[i + 1];
|
29 |
const b = imageData.data[i + 2];
|
30 |
|
31 |
-
if (r
|
32 |
imageData.data[i + 3] = 0;
|
33 |
} else {
|
34 |
imageData.data[i] = color.r;
|
@@ -261,8 +261,8 @@ export default function Canvas({ imageUrl, imageName }: { imageUrl: string, imag
|
|
261 |
|
262 |
const predLabels = async () => {
|
263 |
const length = useLatest ? groupRef.current.length : groupRef.current.length - 1;
|
264 |
-
if (groupRef.current.length === 0) {
|
265 |
-
alert('Please pin an instance');
|
266 |
return
|
267 |
}
|
268 |
const mask = groupRef.current[length - 1].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height });
|
|
|
28 |
const g = imageData.data[i + 1];
|
29 |
const b = imageData.data[i + 2];
|
30 |
|
31 |
+
if (r <= 1 && g <= 1 && b <= 1) {
|
32 |
imageData.data[i + 3] = 0;
|
33 |
} else {
|
34 |
imageData.data[i] = color.r;
|
|
|
261 |
|
262 |
const predLabels = async () => {
|
263 |
const length = useLatest ? groupRef.current.length : groupRef.current.length - 1;
|
264 |
+
if (groupRef.current.length === 0 || classList.length === 0) {
|
265 |
+
alert('Please pin an instance first');
|
266 |
return
|
267 |
}
|
268 |
const mask = groupRef.current[length - 1].toDataURL({ x: 0, y: 0, width: image?.width, height: image?.height });
|
requirements.txt
CHANGED
@@ -6,3 +6,4 @@ Pillow
|
|
6 |
fastapi
|
7 |
uvicorn
|
8 |
timm
|
|
|
|
6 |
fastapi
|
7 |
uvicorn
|
8 |
timm
|
9 |
+
kornia
|