File size: 1,905 Bytes
dfdcd97
a3ee867
e9cd6fd
c95f3e0
9a34a8b
 
0c00155
e0d4d2f
c95f3e0
 
 
0c00155
 
3cd1243
9a34a8b
 
 
 
 
 
e0d4d2f
0c00155
4f39124
0c00155
4f39124
0c00155
 
 
4f39124
 
 
72f4c5c
4f39124
73989e5
 
 
 
 
26c0f04
0c00155
 
 
 
2dd8fe8
9a34a8b
0c00155
73989e5
0c00155
3ba1061
73989e5
 
e0d4d2f
e9cd6fd
 
4f39124
e9cd6fd
4f39124
e9cd6fd
3ba1061
9a34a8b
3ba1061
 
0c00155
 
e0d4d2f
 
e9cd6fd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM 2 (Hiera-Large) checkpoint.
# NOTE: ``device`` is forwarded explicitly because SAM 2's model builder
# defaults to CUDA; without it, loading fails on CPU-only machines.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large", device=device)

def fig2img(fig):
    """Render a Matplotlib figure to a PIL image.

    The figure is written with ``savefig`` into an in-memory buffer. The buffer
    is rewound and handed to ``Image.open``; the returned image reads from it.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    stream.seek(0)
    return Image.open(stream)

def plot_masks(image, masks):
    """Overlay segmentation masks on an image and return the rendering.

    Args:
        image: Background image; anything ``Axes.imshow`` accepts (the caller
            in this file passes a PIL image).
        masks: Iterable of 2-D masks matching the image size; non-zero pixels
            belong to the mask. May be empty.

    Returns:
        The rendered figure as a PIL image (via ``fig2img``).
    """
    masks = list(masks)
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        # Fill mask i with the value i and draw every mask on one shared colour
        # scale, so each object gets its own colour. A binary mask on its own
        # would map every object to the same colour.
        vmax = max(len(masks) - 1, 1)
        for index, mask in enumerate(masks):
            mask = np.asarray(mask)
            overlay = np.ma.masked_where(mask == 0, np.full(mask.shape, index, dtype=float))
            # Pass the colormap by name: ``plt.cm.get_cmap`` was removed in
            # Matplotlib 3.9, while a name string works on all versions.
            ax.imshow(overlay, alpha=0.5, cmap="jet", vmin=0, vmax=vmax)
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Close this specific figure (not pyplot's "current" one) once it has
        # been rasterized, so figures do not accumulate across requests.
        plt.close(fig)

def segment_everything(input_image):
    """Segment every object in an uploaded image with SAM 2.

    Args:
        input_image: Image array from the Gradio ``Image(type="numpy")`` input,
            or ``None`` when nothing has been uploaded.

    Returns:
        A ``(result_image, status_message)`` tuple. ``result_image`` is the
        input with all masks overlaid, or ``None`` when no image was given or
        an error occurred.
    """
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        input_image = Image.fromarray(input_image).convert("RGB")

        # Prompt-free ("everything") segmentation needs the automatic mask
        # generator, which prompts the model with a grid of points and filters
        # the results. SAM2ImagePredictor.predict() only serves explicit
        # prompts; predict([]) fails because no point_labels accompany the
        # (empty) point_coords. The already-loaded model is reused, so no
        # weights are reloaded.
        mask_generator = SAM2AutomaticMaskGenerator(predictor.model)

        # Use bfloat16 autocast only on CUDA instead of hard-coding "cuda",
        # which is disabled (with a warning) on CPU-only machines.
        with torch.inference_mode(), torch.autocast(
            device_type=device, dtype=torch.bfloat16, enabled=device == "cuda"
        ):
            annotations = mask_generator.generate(np.asarray(input_image))

        # Draw large masks first so smaller objects remain visible on top.
        annotations = sorted(annotations, key=lambda ann: ann["area"], reverse=True)
        masks = [ann["segmentation"] for ann in annotations]

        # Plot the results
        result_image = plot_masks(input_image, masks)

        return result_image, f"Segmented everything in the image. Found {len(masks)} objects."

    except Exception as e:
        # UI boundary: report the failure in the status box instead of failing
        # the request, and keep the traceback in the server log.
        logging.getLogger(__name__).exception("Segmentation failed")
        return None, f"An error occurred: {str(e)}"

# Gradio UI: a single image input; outputs are the annotated image and a
# status line.
image_input = gr.Image(type="numpy", label="Upload an image")
segmented_output = gr.Image(type="pil", label="Segmented Image")
status_output = gr.Textbox(label="Status")

iface = gr.Interface(
    fn=segment_everything,
    inputs=[image_input],
    outputs=[segmented_output, status_output],
    title="SAM 2 Everything Segmentation",
    description="Upload an image to segment all objects using SAM 2.",
)

# Start the web app.
iface.launch()