import types

import numpy as np
import streamlit as st
import torch
from distinctipy import distinctipy
from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
                              sam_model_registry)
from torch.nn import functional as F


def get_color(n_colors=200):
    """Return ``n_colors`` visually distinct colors from distinctipy.

    The default of 200 matches the previously hard-coded palette size.
    NOTE(review): presumably RGB tuples of floats in [0, 1] -- confirm against
    distinctipy before relying on the format.
    """
    return distinctipy.get_colors(n_colors)


def medsam_preprocess(self, x: torch.Tensor) -> torch.Tensor:
    """Min-max normalize ``x`` to [0, 1] and zero-pad it to a square input.

    This is bound onto a SAM model in ``get_model`` as a replacement for the
    model's ``preprocess`` method, so ``self`` is the model instance.

    Args:
        self: model whose ``image_encoder.img_size`` is the target side length.
        x: image tensor; its last two dims are treated as (H, W).
            NOTE(review): presumably (B, C, H, W) as supplied by SamPredictor
            -- confirm.

    Returns:
        The normalized tensor, padded on the bottom/right to (img_size, img_size).
    """
    # Global min-max normalization. The clip avoids division by zero for a
    # constant image.
    value_range = torch.clip(x.max() - x.min(), min=1e-8, max=None)
    x = (x - x.min()) / value_range
    # Pad only the bottom and right edges so spatial dims reach img_size.
    h, w = x.shape[-2:]
    target = self.image_encoder.img_size
    return F.pad(x, (0, target - w, 0, target - h))


@st.cache_resource
def get_model(checkpoint='checkpoint/sam_vit_b_01ec64.pth'):
    """Load the ViT-B SAM checkpoint once and wrap it for interactive use.

    Args:
        checkpoint: path to the ``vit_b`` SAM/MedSAM weights.

    Returns:
        Tuple ``(predictor, mask_generator)``: a ``SamPredictor`` and a
        ``SamAutomaticMaskGenerator`` that share the same underlying model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = sam_model_registry['vit_b'](checkpoint=checkpoint)
    # Swap the model's default preprocessing for the min-max variant above.
    model.preprocess = types.MethodType(medsam_preprocess, model)
    # NOTE(review): overrides the model's mask threshold; the reason for 0.5 is
    # not documented here -- confirm with the model's training setup.
    model.mask_threshold = 0.5
    model = model.to(device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    predictor = SamPredictor(model)
    mask_generator = SamAutomaticMaskGenerator(model)
    return predictor, mask_generator


def show_everything(sorted_anns):
    """Render automatic-mask annotations as a single RGBA overlay.

    Args:
        sorted_anns: sequence of annotation dicts. Each has a ``'segmentation'``
            array whose last two dims are (H, W).

    Returns:
        An (H, W, 4) uint8 image. Each annotation is painted with a random RGB
        color at alpha 0.6, and overlapping regions saturate at 255. Returns
        ``np.array([])`` when there are no annotations.
    """
    if len(sorted_anns) == 0:
        return np.array([])
    h, w = sorted_anns[0]['segmentation'].shape[-2:]
    overlay = np.zeros((h, w, 4))
    for ann in sorted_anns:
        m = ann['segmentation']
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        overlay += m.reshape(h, w, 1) * color.reshape(1, 1, -1)
    # Overlapping annotations sum past 1.0. Casting an out-of-range float to
    # uint8 does not saturate, so clip before scaling.
    return (np.clip(overlay, 0.0, 1.0) * 255).astype(np.uint8)


def show_click(masks, colors):
    """Paint each non-empty mask with its color into one RGBA image.

    Args:
        masks: sequence of masks. Each non-empty mask must reshape to
            (H, W, 1), e.g. (H, W) or (1, H, W). Empty entries, such as the
            ``np.array([])`` that ``model_predict_masks_click`` returns when
            there are no points, are skipped.
        colors: sequence parallel to ``masks``. Each entry is broadcast against
            the 4 RGBA channels, so it should hold 4 values -- presumably in
            [0, 1]; confirm with the caller.

    Returns:
        An (H, W, 4) uint8 image. Overlapping masks are summed and saturate at
        255 instead of wrapping around. Returns ``np.array([])`` when every
        mask is empty.
    """
    # An empty placeholder has no (H, W), so size the canvas from the first
    # non-empty mask.
    first = next((m for m in masks if np.size(m) > 0), None)
    if first is None:
        return np.array([])
    h, w = first.shape[-2:]
    # Accumulate in a wide integer type so overlaps cannot overflow uint8.
    total = np.zeros((h, w, 4), dtype=np.int32)
    for mask, color in zip(masks, colors):
        if np.size(mask) == 0:
            continue
        # Same binarization as before: foreground is any value that is
        # non-zero after a uint8 cast.
        foreground = mask.reshape(h, w, 1).astype(np.uint8).astype(bool)
        layer = foreground * 255 * np.asarray(color).reshape(1, 1, -1)
        # Each layer is truncated to uint8 before summing, as before.
        total += layer.astype(np.uint8)
    return np.minimum(total, 255).astype(np.uint8)


def model_predict_masks_click(model, input_points, input_labels):
    """Predict a single mask from click prompts.

    Args:
        model: a ``SamPredictor``. NOTE(review): assumes the image has already
            been set on it -- confirm with the caller.
        input_points: sequence of point coordinates, passed as ``point_coords``.
        input_labels: labels parallel to ``input_points``, passed as
            ``point_labels``.

    Returns:
        The masks from ``model.predict(..., multimask_output=False)``, or
        ``np.array([])`` when there are no points.
    """
    # len() works for lists and NumPy arrays alike. ``== []`` on an array is an
    # elementwise comparison and breaks the truth test.
    if len(input_points) == 0:
        return np.array([])
    masks, _, _ = model.predict(
        point_coords=np.array(input_points),
        point_labels=np.array(input_labels),
        multimask_output=False,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return masks


def model_predict_masks_box(model, center_point, center_label, input_box):
    """Predict one mask per box prompt and merge the results.

    Args:
        model: a ``SamPredictor`` (image presumably already set -- confirm).
        center_point: per-prompt point coordinates. Empty entries are skipped.
        center_label: per-prompt labels. Its length sets how many prompts run.
        input_box: per-prompt box coordinates, parallel to ``center_label``.

    Returns:
        The element-wise ``+`` of all predicted masks (a logical OR when the
        masks are boolean), or ``np.array([])`` if every prompt was skipped.
    """
    combined = None
    for i in range(len(center_label)):
        if len(center_point[i]) == 0:
            continue
        mask, _, _ = model.predict(
            point_coords=np.array([center_point[i]]),
            point_labels=np.array(center_label[i]),
            box=np.array(input_box[i]),
            multimask_output=False,
        )
        # Handle the first mask explicitly. The old bare ``except`` relied on
        # ``np.array([]) + mask`` failing to broadcast, which silently produced
        # an empty result for width-1 masks and hid unrelated errors.
        combined = mask if combined is None else combined + mask
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return np.array([]) if combined is None else combined


def model_predict_masks_everything(mask_generator, image):
    """Run automatic mask generation over the whole image.

    Args:
        mask_generator: a ``SamAutomaticMaskGenerator``.
        image: the image passed straight to ``mask_generator.generate``.

    Returns:
        The list of annotation dicts produced by ``generate``, suitable for
        ``show_everything``.
    """
    masks = mask_generator.generate(image)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return masks