import os
import subprocess
import sys
import zipfile

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor

# Download a file in streamed chunks so large checkpoints don't sit in memory.
def download_file(url, destination):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(destination, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

# Ensure the SAM checkpoint is available.
if not os.path.exists("weights/sam_vit_h_4b8939.pth"):
    os.makedirs("weights", exist_ok=True)
    download_file(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "weights/sam_vit_h_4b8939.pth",
    )

# Install segment-anything and fetch the EfficientSAM repo before importing them.
subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "git+https://github.com/facebookresearch/segment-anything.git"],
    check=True,
)
if not os.path.exists("EfficientSAM"):
    subprocess.run(
        ["git", "clone", "https://github.com/yformer/EfficientSAM.git"],
        check=True,
    )

# `git clone` creates "EfficientSAM" (not "EfficientSAM-main", which is what a
# GitHub zip download would produce), so put that directory on the import path.
sys.path.append(os.path.abspath("EfficientSAM"))

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms
from efficient_sam.build_efficient_sam import build_efficient_sam_vits

# Constants
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"

# Load SAM and its automatic ("segment everything") mask generator.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
mask_generator_sam = SamAutomaticMaskGenerator(sam)

# Load EfficientSAM: the cloned repo ships the ViT-S checkpoint as a zip, and
# build_efficient_sam_vits() expects it at "weights/efficient_sam_vits.pt".
with zipfile.ZipFile("EfficientSAM/weights/efficient_sam_vits.pt.zip", "r") as zip_ref:
    zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()

def process_small_region(rles):
    """Post-process RLE masks: fill small holes, drop small islands, then
    deduplicate the cleaned masks with box NMS (from the EfficientSAM demo)."""
    new_masks = []
    scores = []
    min_area = 100
    nms_thresh = 0.7
    for rle in rles:
        mask = rle_to_mask(rle[0])
        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed
        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        # Give untouched masks score 1.0 so NMS prefers them over edited ones.
        scores.append(float(unchanged))
    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),  # a single NMS category
        iou_threshold=nms_thresh,
    )
    # Only recompute RLEs for masks that were actually changed.
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            mask_torch = masks[i_mask].unsqueeze(0)
            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
    masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
    return masks

def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
    """Run EfficientSAM on a batch of point prompts and keep only confident,
    stable masks (predicted IoU > 0.7, stability score > 0.9)."""
    predicted_masks, predicted_iou = model(img[None, ...], points, point_labels)
    # Sort each query's multimask outputs by predicted IoU, best first.
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )
    predicted_masks = predicted_masks[0]
    iou = predicted_iou_scores[0, :, 0]
    index_iou = iou > 0.7
    iou_ = iou[index_iou]
    masks = predicted_masks[index_iou]
    score = calculate_stability_score(masks, 0.0, 1.0)
    score = score[:, 0]
    index = score > 0.9
    masks = masks[index]
    iou_ = iou_[index]
    # Threshold the logits at 0 to get binary masks.
    masks = torch.ge(masks, 0.0)
    return masks, iou_
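# Illustrative sketch (not part of the original script): the EfficientSAM call
# in get_predictions_given_embeddings_and_queries above takes a [B, 3, H, W]
# image tensor, point prompts of shape [B, num_queries, points_per_query, 2]
# in (x, y) pixel coordinates, and point labels of shape
# [B, num_queries, points_per_query], and returns mask logits plus predicted
# IoU scores. Uncomment to sanity-check the loaded model on random data:
#
# _img = torch.rand(1, 3, 256, 256)
# _pts = torch.tensor([[[[128.0, 128.0]]]])
# _lbls = torch.ones(1, 1, 1)
# with torch.no_grad():
#     _logits, _ious = efficient_sam_vits_model(_img, _pts, _lbls)
# print(_logits.shape, _ious.shape)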
def run_everything_ours(image_np, model):
    """EfficientSAM "segment everything": prompt the model with a regular
    GRID_SIZE x GRID_SIZE grid of points covering the whole image."""
    model = model.cpu()
    # Gradio hands us an RGB array already, so no BGR->RGB conversion is needed.
    img_tensor = ToTensor()(image_np)
    _, original_image_h, original_image_w = img_tensor.shape
    GRID_SIZE = 32
    xy = []
    for i in range(GRID_SIZE):
        curr_x = 0.5 + i / GRID_SIZE * original_image_w
        for j in range(GRID_SIZE):
            curr_y = 0.5 + j / GRID_SIZE * original_image_h
            xy.append([curr_x, curr_y])
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    point_labels = torch.ones(num_pts, 1)
    with torch.no_grad():
        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model.cpu(),
        )
    # Keep the best mask per prompt, then clean up and deduplicate.
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    predicted_masks = process_small_region(rle)
    return predicted_masks

def show_anns_ours(masks, image):
    """Draw each mask's outline on the image in green."""
    for mask in masks:
        contours, _ = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
    return image

def process_image(image):
    # Gradio delivers a PIL image in RGB; both models expect RGB input,
    # so convert to a numpy array without any channel swap.
    image_np = np.array(image)

    # Segment everything with SAM and fill the found regions in green.
    sam_result = mask_generator_sam.generate(image_np)
    sam_annotated_image = image_np.copy()
    for mask in sam_result:
        sam_annotated_image[mask["segmentation"]] = [0, 255, 0]

    # Segment everything with EfficientSAM and outline the found regions.
    mask_efficient_sam_vits = run_everything_ours(image_np, efficient_sam_vits_model)
    efficient_sam_annotated_image = show_anns_ours(mask_efficient_sam_vits, image_np.copy())

    return [image, sam_annotated_image, efficient_sam_annotated_image]

# Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="SAM Segmented"),
        gr.Image(type="pil", label="EfficientSAM Segmented"),
    ],
    title="SAM vs EfficientSAM Comparison",
    description="Upload an image to compare the segmentation results of SAM and EfficientSAM.",
)

interface.launch(debug=True)
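# Usage sketch (assumption: this file is saved as app.py): `python app.py`
# downloads the checkpoints on first run and serves the demo at the local URL
# Gradio prints. To run the comparison headlessly instead, comment out
# interface.launch(...) above and call process_image directly; "input.jpg" is
# a placeholder path:
#
# orig, sam_out, eff_out = process_image(Image.open("input.jpg").convert("RGB"))
# Image.fromarray(sam_out).save("sam_out.png")
# Image.fromarray(eff_out).save("efficient_sam_out.png")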