import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
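
# Note: calling SamModel with no point prompt (the "Segment Everything" path
# below) gives only a rough whole-image segmentation. transformers also ships
# a dedicated "mask-generation" pipeline that prompts SAM with a grid of
# points; a minimal sketch, not wired into this app:
#   from transformers import pipeline
#   generator = pipeline("mask-generation", model="facebook/sam-vit-base")
#   grid_masks = generator(image, points_per_batch=64)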
def segment_image(input_image, segment_anything):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        # Convert the input array to a PIL image
        input_image = Image.fromarray(input_image).convert("RGB")

        # Store the original size as (width, height)
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."

        if segment_anything:
            # No point prompt: segment the whole image
            inputs = processor(input_image, return_tensors="pt").to(device)
        else:
            # Use the center of the image as a point prompt
            width, height = original_size
            center_point = [[width // 2, height // 2]]
            inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)

        # Generate masks
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process masks back to the original resolution.
        # Each list entry has shape (point_batch_size, num_masks, H, W).
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )

        if segment_anything:
            # Combine all candidate masks into a single 2D mask
            combined_mask = np.any(masks[0].numpy() > 0.5, axis=(0, 1))
        else:
            # Use the first of the candidate masks
            combined_mask = masks[0][0, 0].numpy() > 0.5

        # Ensure the mask is 2D
        if combined_mask.ndim > 2:
            combined_mask = combined_mask.squeeze()

        # Resize the mask to match the original image size; cv2 expects (width, height)
        combined_mask = cv2.resize(combined_mask.astype(np.uint8), (original_size[0], original_size[1])) > 0

        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

        return result_image, "Segmentation completed successfully."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Checkbox(label="Segment Everything")
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point."
)
# Launch the interface
iface.launch()
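
# Quick local check without the UI (a sketch; "example.jpg" is a placeholder
# path, not part of this Space):
#   img = np.array(Image.open("example.jpg"))
#   result, status = segment_image(img, segment_anything=False)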