File size: 1,746 Bytes
dfdcd97
a3ee867
e9cd6fd
 
c95f3e0
 
e0d4d2f
c95f3e0
 
 
 
 
 
e0d4d2f
e9cd6fd
c95f3e0
 
 
 
 
e0d4d2f
c95f3e0
 
 
e0d4d2f
c95f3e0
 
 
 
 
 
 
e0d4d2f
c95f3e0
 
e9cd6fd
 
c95f3e0
 
 
 
e9cd6fd
 
e0d4d2f
e9cd6fd
 
e0d4d2f
e9cd6fd
 
 
 
 
c95f3e0
 
e0d4d2f
 
e9cd6fd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Set up device: prefer GPU when CUDA is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained SAM (Segment Anything Model) checkpoint and its
# processor from the HuggingFace hub. NOTE: this downloads weights on first
# run and moves the model to `device` — a module-level side effect.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

def segment_image(input_image, points):
    """Segment an object in ``input_image`` with SAM and overlay the mask.

    Args:
        input_image: H x W x 3 uint8 numpy array (RGB), as delivered by
            ``gr.Image(type="numpy")``.
        points: prompt point coordinates for SAM, e.g. ``[[x, y], ...]``.
            NOTE(review): the second Gradio input is wired as a sketch
            *image*, not coordinates — verify against the interface.

    Returns:
        H x W x 3 uint8 numpy array: the input with the best-scoring
        predicted mask blended on top in red.
    """
    pil_image = Image.fromarray(input_image)

    # Tokenize/resize the image and prompt points for the model.
    inputs = processor(pil_image, input_points=[points], return_tensors="pt").to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-resolution predicted masks back to the original
    # image size. `masks` is a list (one entry per image) of tensors
    # shaped (num_prompts, num_mask_candidates, H, W).
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Predicted IoU per mask candidate, shaped (batch, num_prompts, num_masks).
    scores = outputs.iou_scores

    # Bug fix: the original used masks[0][0] directly, which is the whole
    # stack of mask candidates (num_masks, H, W), not a single 2-D mask —
    # boolean-indexing an (H, W, 3) image with it is incorrect. Select the
    # single candidate with the highest predicted IoU instead.
    best_idx = scores[0][0].argmax().item()
    mask = masks[0][0][best_idx].numpy()

    # Blend a semi-transparent red overlay where the mask is active.
    result_image = np.array(pil_image)
    mask_rgb = np.zeros_like(result_image)
    mask_rgb[mask > 0.5] = [255, 0, 0]  # Red color for the mask
    result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

    return result_image

# Create Gradio interface wiring segment_image to two image inputs and one
# image output.
# NOTE(review): `tool="sketch"` and `brush_radius` are Gradio 3.x-only
# kwargs (removed in Gradio 4.x) — confirm the pinned gradio version.
# NOTE(review): the second input delivers a numpy *image* (the sketch), but
# segment_image forwards it as `input_points`, which SAM expects to be
# pixel coordinates — this wiring looks inconsistent; verify intent.
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy"),
        gr.Image(type="numpy", tool="sketch", brush_radius=5, label="Click on objects to segment")
    ],
    outputs=gr.Image(type="numpy"),
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Click on objects in the image to segment them using SAM."
)

# Launch the interface (blocking call; starts the local web server).
iface.launch()