"""Streamlit helpers for running a (Med)SAM predictor and rendering its masks."""
from segment_anything import SamPredictor, sam_model_registry
import torch
import numpy as np
from distinctipy import distinctipy
import streamlit as st


def get_checkpoint_path(model):
    """Return the checkpoint file used to build the SAM model.

    NOTE(review): ``model`` is currently ignored -- every registry key is
    built from the MedSAM ViT-B weights. Confirm this is intended before
    exposing other model types.
    """
    return 'checkpoint/medsam_vit_b.pth'


@st.cache_data
def get_color(n_colors=200):
    """Return ``n_colors`` visually distinct RGB colours from distinctipy.

    distinctipy builds the palette by random search, which can be slow for
    large palettes and gives a different palette on each call. Caching keeps
    the palette stable across Streamlit reruns and avoids recomputing it.

    Args:
        n_colors: Number of colours to generate (default 200, as before).

    Returns:
        List of RGB tuples as returned by ``distinctipy.get_colors``.
    """
    return distinctipy.get_colors(n_colors)


@st.cache_resource
def get_model(model):
    """Build the SAM model registered under ``model`` and wrap it in a predictor.

    Cached with ``st.cache_resource`` so the weights are loaded once instead of
    on every Streamlit rerun.

    Args:
        model: Key into ``segment_anything.sam_model_registry``.

    Returns:
        A ``SamPredictor`` wrapping the model, placed on CUDA when available,
        otherwise on CPU.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    build_sam = sam_model_registry[model]
    sam = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
    predictor = SamPredictor(sam)
    if use_cuda:
        # Release allocator blocks cached while building/loading the model.
        torch.cuda.empty_cache()
    return predictor


@st.cache_data
def show_everything(sorted_anns):
    """Render a list of annotations as a single RGBA overlay.

    Args:
        sorted_anns: Sequence of annotation dicts, each with a
            ``'segmentation'`` array whose last two dims are ``(h, w)``
            (presumably boolean masks from SAM's automatic mask
            generator -- TODO confirm).

    Returns:
        ``(h, w, 4)`` uint8 RGBA image in which every annotation is painted
        with a random RGB colour at alpha 0.6, or ``None`` when
        ``sorted_anns`` is empty.
    """
    if len(sorted_anns) == 0:
        return None
    h, w = sorted_anns[0]['segmentation'].shape[-2:]
    overlay = np.zeros((h, w, 4))
    for ann in sorted_anns:
        seg = ann['segmentation']
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        overlay += seg.reshape(h, w, 1) * rgba.reshape(1, 1, -1)
    # Overlapping segments sum past 1.0; clip so the uint8 cast below
    # saturates at 255 instead of overflowing.
    overlay = np.clip(overlay, 0.0, 1.0) * 255
    st.success('Process completed!', icon="✅")
    return overlay.astype(np.uint8)


def show_click(masks, colors):
    """Paint each click-prompted mask in its colour and merge them into one image.

    Args:
        masks: Sequence of mask arrays whose last two dims are ``(h, w)``
            (e.g. the output of ``model_predict_masks_click``). Empty arrays
            are skipped.
        colors: Sequence of array-likes paired with ``masks`` by position,
            broadcastable against the 4 RGBA channels (values presumably in
            ``[0, 1]`` -- TODO confirm).

    Returns:
        ``(h, w, 4)`` uint8 image in which overlapping masks are combined by
        saturating addition, or ``None`` if every mask is empty.
    """
    # Take the canvas size from the first non-empty mask; masks[0] itself
    # may be an empty placeholder.
    first = next((m for m in masks if np.size(m) > 0), None)
    if first is None:
        return None
    h, w = np.shape(first)[-2:]
    # Accumulate in a wide integer type so overlaps cannot wrap around.
    total = np.zeros((h, w, 4), dtype=np.int64)
    for mask, color in zip(masks, colors):
        if np.size(mask) == 0:
            continue
        # Any non-zero value (after a uint8 cast) counts as "on".
        on = np.asarray(mask).reshape(h, w, 1).astype(np.uint8).astype(bool)
        layer = on * 255 * np.asarray(color).reshape(1, 1, -1)
        total += np.clip(layer, 0, 255).astype(np.uint8)
    st.success('Process completed!', icon="✅")
    return np.clip(total, 0, 255).astype(np.uint8)


def model_predict_masks_click(model, input_points, input_labels):
    """Predict a single mask from point prompts.

    Args:
        model: A ``SamPredictor`` whose image has already been set
            (presumably via ``set_image`` by the caller -- TODO confirm).
        input_points: Sequence of ``(x, y)`` click coordinates.
        input_labels: Sequence of point labels, one per point
            (SAM convention: 1 = foreground, 0 = background).

    Returns:
        The masks from ``SamPredictor.predict`` with ``multimask_output=False``,
        or an empty array when no points are given.
    """
    # len() works for both lists and numpy arrays; ``arr == []`` does not.
    if len(input_points) == 0:
        return np.array([])
    input_labels = np.array(input_labels)
    input_points = np.array(input_points)
    masks, _, _ = model.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return masks