File size: 2,159 Bytes
dfdcd97
a3ee867
e9cd6fd
 
c95f3e0
 
e0d4d2f
c95f3e0
 
 
 
 
 
e0d4d2f
7bee2b4
c95f3e0
 
 
7bee2b4
 
 
 
 
 
 
 
e0d4d2f
c95f3e0
 
 
e0d4d2f
c95f3e0
 
 
 
 
 
e0d4d2f
c95f3e0
7bee2b4
 
 
 
 
 
e9cd6fd
 
c95f3e0
 
7bee2b4
c95f3e0
e9cd6fd
 
e0d4d2f
e9cd6fd
 
e0d4d2f
e9cd6fd
 
7bee2b4
e9cd6fd
 
c95f3e0
7bee2b4
e0d4d2f
 
e9cd6fd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Prefer the GPU when torch reports CUDA as available; otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained SAM ViT-Base checkpoint, place it on `device`, and load
# the processor that matches the same checkpoint.
_SAM_CHECKPOINT = "facebook/sam-vit-base"
model = SamModel.from_pretrained(_SAM_CHECKPOINT)
model = model.to(device)
processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)

def segment_image(input_image, segment_anything):
    """Run SAM on an image and return it with the predicted mask tinted red.

    Args:
        input_image: The image as a NumPy array (as produced by
            ``gr.Image(type="numpy")``). ``None`` -- nothing uploaded -- is
            returned unchanged.
        segment_anything: When true, call the model with no prompt and paint
            the union of every mask it returns. When false, prompt the model
            with a single point at the image centre and paint the mask with
            the highest predicted IoU score.

    Returns:
        A uint8 RGB array of the input with the mask blended in red, or
        ``None`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        return None

    # Force 3-channel RGB so the (H, W, 3) overlay below also works for
    # grayscale or RGBA uploads.
    pil_image = Image.fromarray(input_image).convert("RGB")

    if segment_anything:
        # NOTE(review): no prompt is supplied here, so this is SAM's
        # "no prompt" prediction, not grid-based automatic mask generation.
        inputs = processor(pil_image, return_tensors="pt").to(device)
    else:
        # PIL reports size as (width, height); SAM point prompts are (x, y).
        width, height = pil_image.size
        center_point = [[width // 2, height // 2]]
        inputs = processor(
            pil_image, input_points=[center_point], return_tensors="pt"
        ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the predicted masks to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # masks[0] holds every mask predicted for our single image, with (H, W)
    # as its last two axes. Flatten the leading axes so each row is one
    # full-resolution 2-D candidate mask; a 2-D mask is required to index
    # the (H, W, 3) image below.
    image_masks = masks[0].numpy()
    mask_h, mask_w = image_masks.shape[-2:]
    candidates = image_masks.reshape(-1, mask_h, mask_w) > 0.5

    if segment_anything:
        # Union of all candidate masks.
        combined_mask = candidates.any(axis=0)
    else:
        # iou_scores has one score per predicted mask in the same order as
        # pred_masks; keep the candidate the model scores highest.
        scores = outputs.iou_scores[0].cpu().reshape(-1)
        combined_mask = candidates[int(scores.argmax())]

    # Overlay the mask in red on the original image.
    result_image = np.array(pil_image)
    mask_rgb = np.zeros_like(result_image)
    mask_rgb[combined_mask] = [255, 0, 0]
    return cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

# Input widgets, in order: the image to segment and the "everything" toggle.
_input_components = [
    gr.Image(type="numpy"),
    gr.Checkbox(label="Segment Everything"),
]
# Output widget: the image returned by segment_image.
_output_component = gr.Image(type="numpy")

iface = gr.Interface(
    fn=segment_image,
    inputs=_input_components,
    outputs=_output_component,
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point.",
)

# Start serving the Gradio app.
iface.launch()