Spaces:
Sleeping
Sleeping
File size: 1,905 Bytes
dfdcd97 a3ee867 e9cd6fd c95f3e0 9a34a8b 0c00155 e0d4d2f c95f3e0 0c00155 3cd1243 9a34a8b e0d4d2f 0c00155 4f39124 0c00155 4f39124 0c00155 4f39124 72f4c5c 4f39124 73989e5 26c0f04 0c00155 2dd8fe8 9a34a8b 0c00155 73989e5 0c00155 3ba1061 73989e5 e0d4d2f e9cd6fd 4f39124 e9cd6fd 4f39124 e9cd6fd 3ba1061 9a34a8b 3ba1061 0c00155 e0d4d2f e9cd6fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Set up device: use CUDA when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load SAM 2 model. SAM2's model builder defaults to device="cuda", so the
# detected device is passed explicitly; without it loading fails on hosts
# that have no GPU.
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-large", device=device
)
def fig2img(fig):
    """Rasterise a Matplotlib figure into a PIL image.

    The figure is written to an in-memory byte stream with Matplotlib's
    default ``savefig`` settings. The stream is rewound and handed to
    ``Image.open``, whose result is returned.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    # Rewind so PIL reads the freshly written bytes from the start.
    stream.seek(0)
    return Image.open(stream)
def plot_masks(image, masks):
    """Overlay segmentation masks on an image and return the rendering.

    Args:
        image: Background image; anything ``Axes.imshow`` accepts (the caller
            in this file passes a PIL RGB image).
        masks: Iterable of 2-D masks. Pixels equal to 0 are left transparent;
            every other pixel is tinted with a colour unique to that mask.

    Returns:
        PIL image of the rendered figure, produced by ``fig2img``.
    """
    masks = list(masks)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap`` is
    # the supported spelling. Look the colormap up once, not per mask.
    cmap = plt.get_cmap("jet")
    # Fixed colour range across all masks: mask i is drawn with value i + 1,
    # so each object gets its own colour. Previously every binary mask
    # normalised to the same single colour.
    top = max(len(masks), 1)
    for index, mask in enumerate(masks):
        mask = np.asarray(mask)
        labels = np.ma.masked_where(mask == 0, np.full(mask.shape, index + 1))
        ax.imshow(labels, alpha=0.5, cmap=cmap, vmin=1, vmax=top)
    ax.axis("off")
    # Close this specific figure (not merely the "current" one) so repeated
    # requests do not accumulate open figures; a closed figure can still be
    # saved by fig2img.
    plt.close(fig)
    return fig2img(fig)
# Lazily-built automatic mask generator, shared across requests.
_MASK_GENERATOR = None


def _get_mask_generator():
    """Build (once) a SAM 2 automatic mask generator reusing ``predictor.model``."""
    global _MASK_GENERATOR
    if _MASK_GENERATOR is None:
        # Local import keeps this block self-contained; sam2 is already a
        # dependency of this app.
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

        # Reuse the already-loaded weights instead of loading the model twice.
        _MASK_GENERATOR = SAM2AutomaticMaskGenerator(predictor.model)
    return _MASK_GENERATOR


def segment_everything(input_image):
    """Segment every object in an uploaded image and render the masks.

    Args:
        input_image: Image array from the Gradio ``gr.Image(type="numpy")``
            input, or ``None`` when nothing was uploaded.

    Returns:
        Tuple ``(result_image, status_message)``. ``result_image`` is a PIL
        image with masks overlaid, or ``None`` when no image was given or an
        error occurred. ``status_message`` is the text for the status box.
    """
    import contextlib

    if input_image is None:
        return None, "Please upload an image before submitting."
    try:
        rgb_image = Image.fromarray(input_image).convert("RGB")
        # bfloat16 autocast only applies on CUDA; on CPU run in full precision.
        amp = (
            torch.autocast("cuda", dtype=torch.bfloat16)
            if torch.cuda.is_available()
            else contextlib.nullcontext()
        )
        with torch.inference_mode(), amp:
            # SAM2ImagePredictor.predict is prompt-driven. Passing [] as
            # point_coords without point_labels fails its prompt assertion,
            # and no prompt does not mean "everything". The automatic mask
            # generator prompts a grid of points over the whole image instead.
            annotations = _get_mask_generator().generate(np.asarray(rgb_image))
        # Draw large masks first so smaller objects stay visible on top.
        annotations.sort(key=lambda ann: ann["area"], reverse=True)
        masks = [ann["segmentation"] for ann in annotations]
        result_image = plot_masks(rgb_image, masks)
        return result_image, f"Segmented everything in the image. Found {len(masks)} objects."
    except Exception as e:  # UI boundary: report the failure in the status box
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface: one numpy image in, a rendered PIL image plus a
# status message out (matching segment_everything's return tuple).
iface = gr.Interface(
    fn=segment_everything,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="SAM 2 Everything Segmentation",
    description="Upload an image to segment all objects using SAM 2.",
)
# Launch the interface (the stray trailing "|" residue was a syntax error).
iface.launch()