import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor
from PIL import Image

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
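# "facebook/sam-vit-base" is the smallest SAM checkpoint; "facebook/sam-vit-large"
# and "facebook/sam-vit-huge" can be swapped in for higher-quality masks at the
# cost of more memory and latency.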

def segment_image(input_image, segment_anything):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        
        # Convert input_image to PIL Image
        input_image = Image.fromarray(input_image).convert("RGB")
        
        # Store original size
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."
        
        if segment_anything:
            # Segment everything in the image
            inputs = processor(input_image, return_tensors="pt").to(device)
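            # Note: with no point prompt, SAM predicts masks from an empty prompt;
            # a full "segment everything" pipeline normally prompts with a grid of points.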
        else:
            # Use the center of the image as a point prompt
            width, height = original_size
            center_point = [[width // 2, height // 2]]
            inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
        
        # Generate masks
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Post-process masks
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )
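        # post_process_masks upscales the low-resolution predicted masks back to the
        # original image resolution and binarizes them.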
        
        # Convert the masks to a single 2D boolean numpy array.
        # Each entry of `masks` has shape (num_prompts, num_mask_candidates, height, width).
        if segment_anything:
            # Combine every predicted mask candidate into one mask
            combined_mask = np.any(masks[0].numpy(), axis=(0, 1))
        else:
            # Use the first mask candidate for the point prompt
            combined_mask = masks[0][0][0].numpy()
        
        # Ensure mask is 2D
        if combined_mask.ndim > 2:
            combined_mask = combined_mask.squeeze()
        
        # Resize mask to match original image size (cv2 expects (width, height))
        combined_mask = cv2.resize(combined_mask.astype(np.uint8), (original_size[0], original_size[1])) > 0
        
        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
        
        return result_image, "Segmentation completed successfully."
    
    except Exception as e:
        return None, f"An error occurred: {str(e)}"

# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Checkbox(label="Segment Everything")
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Segment Anything Model (SAM) Image Segmentation",
    description="Upload an image and choose whether to segment everything or use a center point."
)

# Launch the interface
iface.launch()
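
# Run with: python app.py
# By default Gradio serves the UI locally on http://127.0.0.1:7860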