haritsahm
Fix when no mask result found
4e097eb
raw
history blame
3.8 kB
import types
import numpy as np
import streamlit as st
import torch
from distinctipy import distinctipy
from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
sam_model_registry)
from torch.nn import functional as F
def get_color(n_colors=200):
    """Return a palette of colors generated by ``distinctipy``.

    Args:
        n_colors: Number of colors to request. Defaults to 200, the value
            that used to be hard-coded, so existing ``get_color()`` calls
            behave as before.

    Returns:
        The result of ``distinctipy.get_colors(n_colors)``.
        NOTE(review): per the distinctipy docs this is a list of
        ``(r, g, b)`` float tuples in [0, 1] -- confirm against callers
        that append an alpha channel.
    """
    return distinctipy.get_colors(n_colors)
def medsam_preprocess(self, x: torch.Tensor) -> torch.Tensor:
    """Min-max scale ``x`` to [0, 1] and pad it up to the encoder input size.

    Written to be bound to a model instance as its ``preprocess`` method, so
    ``self`` is the model and must expose ``image_encoder.img_size``.

    Args:
        x: Input tensor. The last two dimensions are treated as height and
            width when padding.
            NOTE(review): the previous inline comment described the layout
            as (H, W, 3), which does not match padding over the last two
            dims -- presumably the input is channels-first; confirm.

    Returns:
        The scaled tensor, padded on the right and bottom so that its last
        two dimensions equal ``self.image_encoder.img_size``.
    """
    # Min-max scaling; the span is floored at 1e-8 so a constant image does
    # not divide by zero.
    lowest = x.min()
    span = torch.clip(x.max() - lowest, min=1e-8, max=None)
    scaled = (x - lowest) / span

    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    target = self.image_encoder.img_size
    height, width = scaled.shape[-2:]
    pad_bottom = target - height
    pad_right = target - width
    return F.pad(scaled, (0, pad_right, 0, pad_bottom))
@st.cache_resource
def get_model(checkpoint='checkpoint/sam_vit_b_01ec64.pth'):
    """Load the ``vit_b`` SAM model and wrap it for prompted and automatic use.

    The model's ``preprocess`` method is replaced by ``medsam_preprocess``
    and its ``mask_threshold`` is set to 0.5 before it is moved to CUDA
    (when available) or the CPU.

    Args:
        checkpoint: Path to the checkpoint file handed to the ``vit_b``
            builder from ``sam_model_registry``.

    Returns:
        A ``(SamPredictor, SamAutomaticMaskGenerator)`` pair that share the
        same underlying model instance.
    """
    cuda_ready = torch.cuda.is_available()
    target_device = torch.device("cuda" if cuda_ready else "cpu")

    build_vit_b = sam_model_registry['vit_b']
    sam = build_vit_b(checkpoint=checkpoint)

    # Bind medsam_preprocess to this instance so calls to sam.preprocess(x)
    # use it instead of the model's own preprocessing.
    sam.preprocess = types.MethodType(medsam_preprocess, sam)
    sam.mask_threshold = 0.5

    sam = sam.to(target_device)
    if cuda_ready:
        torch.cuda.empty_cache()

    return SamPredictor(sam), SamAutomaticMaskGenerator(sam)
def show_everything(sorted_anns):
    """Render a list of mask annotations as a single RGBA overlay.

    Each annotation's ``'segmentation'`` mask is tinted with a random RGB
    color at alpha 0.6 and added onto a shared canvas.

    Args:
        sorted_anns: Sequence of dicts, each holding a ``'segmentation'``
            array that can be reshaped to ``(h, w, 1)``. ``h`` and ``w`` are
            taken from the last two dimensions of the first entry.

    Returns:
        A ``uint8`` array of shape ``(h, w, 4)``, or an empty array when
        ``sorted_anns`` is empty.
    """
    if len(sorted_anns) == 0:
        return np.array([])

    h, w = sorted_anns[0]['segmentation'].shape[-2:]
    overlay = np.zeros((h, w, 4))
    for ann in sorted_anns:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        overlay += ann['segmentation'].reshape(h, w, 1) * rgba.reshape(1, 1, -1)

    # Where masks overlap the summed channels exceed 1.0; clip before the
    # cast so values saturate at 255 instead of wrapping around in uint8.
    return np.clip(overlay * 255, 0, 255).astype(np.uint8)
def show_click(masks, colors):
    """Compose per-prompt masks into one RGBA overlay, one color per mask.

    Args:
        masks: Sequence of mask arrays, each reshapeable to ``(h, w, 1)``.
            Empty arrays are treated as "no result" placeholders and skipped.
        colors: Sequence of color arrays paired with ``masks``; each is
            reshaped to ``(1, 1, -1)`` and multiplied by 255.
            NOTE(review): this implies 4-component RGBA values in [0, 1] --
            confirm against the caller.

    Returns:
        A ``uint8`` array of shape ``(h, w, 4)``, where ``h, w`` come from
        the first non-empty mask; an empty array when there is no non-empty
        mask at all.
    """
    # Take the canvas size from the first usable mask rather than masks[0],
    # which may be an empty placeholder (or missing altogether).
    canvas_hw = next(
        (np.asarray(m).shape[-2:] for m in masks if np.asarray(m).size > 0),
        None,
    )
    if canvas_hw is None:
        return np.array([])
    h, w = canvas_hw

    # Accumulate in a wider dtype and saturate at 255 after every layer so
    # overlapping masks cannot wrap around in uint8.
    total = np.zeros((h, w, 4), dtype=np.uint16)
    for mask, color in zip(masks, colors):
        mask = np.asarray(mask)
        if mask.size == 0:
            continue
        # Binarize exactly as before: cast to uint8 first, then to 0/1.
        binary = mask.reshape(h, w, 1).astype(np.uint8).astype(bool).astype(np.uint8)
        layer = binary * 255 * np.asarray(color).reshape(1, 1, -1)
        total = np.minimum(total + layer.astype(np.uint8), 255)
    return total.astype(np.uint8)
def model_predict_masks_click(model, input_points, input_labels):
    """Predict a mask from click prompts.

    Args:
        model: Object exposing ``predict(point_coords=..., point_labels=...,
            multimask_output=...)`` that returns a 3-tuple whose first item
            is the mask array (e.g. the predictor built by ``get_model``).
        input_points: Click coordinates; converted with ``np.array``.
        input_labels: One label per point; converted with ``np.array``.

    Returns:
        The first element of ``model.predict``'s result, or an empty array
        when there are no input points.
    """
    # Length-based check so an empty tuple / ndarray or None is also treated
    # as "no prompts", not just the literal empty list.
    if input_points is None or len(input_points) == 0:
        return np.array([])

    masks, _, _ = model.predict(
        point_coords=np.array(input_points),
        point_labels=np.array(input_labels),
        multimask_output=False,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return masks
def model_predict_masks_box(model, center_point, center_label, input_box):
    """Predict one mask per box prompt and merge them with ``+``.

    Args:
        model: Object exposing ``predict(point_coords=..., point_labels=...,
            box=..., multimask_output=...)`` that returns a 3-tuple whose
            first item is the mask array.
        center_point: Per-box point prompt; entries that are empty (or
            ``None``) cause that box to be skipped.
        center_label: Per-box label(s) for the point prompt; its length
            determines how many boxes are processed.
        input_box: Per-box coordinates passed to ``predict`` as ``box``.

    Returns:
        The masks of all processed boxes combined with ``+`` (for boolean
        arrays NumPy evaluates this as a logical OR), or an empty array when
        no box was processed.
    """
    combined = None
    for i, label in enumerate(center_label):
        point = center_point[i]
        if point is None or len(point) == 0:
            continue
        mask, _, _ = model.predict(
            point_coords=np.array([point]),
            point_labels=np.array(label),
            box=np.array(input_box[i]),
            multimask_output=False,
        )
        # An explicit "first mask" branch replaces the old bare try/except,
        # which relied on a broadcasting error against an empty array and
        # swallowed every other exception as well.
        combined = mask if combined is None else combined + mask

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return np.array([]) if combined is None else combined
def model_predict_masks_everything(mask_generator, image):
    """Run automatic mask generation on ``image``.

    Args:
        mask_generator: Object exposing ``generate(image)`` (e.g. the
            generator built by ``get_model``).
        image: Passed to ``generate`` unchanged.

    Returns:
        Whatever ``mask_generator.generate(image)`` returns.
    """
    annotations = mask_generator.generate(image)

    # Release cached GPU memory after inference when CUDA is in use.
    cuda_in_use = torch.cuda.is_available()
    if cuda_in_use:
        torch.cuda.empty_cache()

    return annotations