import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from scipy.ndimage import label

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SAM model and processor
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Load BLIP model and processor for image-to-text
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
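# Note: larger checkpoints such as "facebook/sam-vit-large" or "facebook/sam-vit-huge"
# could be substituted here for better masks, at the cost of memory and speed.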

def process_mask(mask, target_size):
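    """Reduce a mask to 2-D, threshold it at 0.5, resize it to target_size
    (a PIL (width, height) tuple), and return it as a boolean array."""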
    if mask.ndim > 2:
        mask = mask.squeeze()
    if mask.ndim > 2:
        mask = mask[0]
    mask = (mask > 0.5).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask)
    mask_image = mask_image.resize(target_size, Image.NEAREST)
    return np.array(mask_image) > 0

def is_cat_like(mask, image_area):
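    """Heuristically decide whether a binary mask looks like a cat: its largest
    connected component should cover 5-30% of the image and be moderately elongated."""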
    labeled, num_features = label(mask)
    if num_features == 0:
        return False
    
    largest_component = (labeled == (np.bincount(labeled.flatten())[1:].argmax() + 1))
    area = largest_component.sum()
    
    # Check if the area is reasonable for a cat (between 5% and 30% of image)
    if not (0.05 * image_area < area < 0.3 * image_area):
        return False
    
    # Check whether the largest component is elongated like a cat silhouette,
    # using the component's bounding box rather than the full mask dimensions
    rows, cols = np.where(largest_component)
    height = rows.max() - rows.min() + 1
    width = cols.max() - cols.min() + 1
    aspect_ratio = max(height, width) / max(min(height, width), 1)
    
    return 1.5 < aspect_ratio < 3  # Most cat silhouettes fall roughly in this range

def segment_image(input_image, object_name):
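    """Caption the image with BLIP, run SAM to get mask proposals, and overlay the
    largest proposal that passes the cat-shape heuristic (the caption must also
    mention a cat-like word). Returns the image and a status message."""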
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        
        input_image = Image.fromarray(input_image).convert("RGB")
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."
        
        # Generate an image caption with BLIP (used below to check whether a cat is present)
        blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
        caption = blip_model.generate(**blip_inputs, max_length=50)
        caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
        
        # Encode the image for SAM (no point or box prompts are given, so SAM returns its generic mask proposals)
        sam_inputs = sam_processor(input_image, return_tensors="pt").to(device)
        
        # Generate masks
        with torch.no_grad():
            sam_outputs = sam_model(**sam_inputs)
        
        # Post-process masks: rescale SAM's predicted masks back to the original image resolution
        masks = sam_processor.image_processor.post_process_masks(
            sam_outputs.pred_masks.cpu(),
            sam_inputs["original_sizes"].cpu(),
            sam_inputs["reshaped_input_sizes"].cpu()
        )
        
        # Find the largest mask that passes the cat-shape heuristic (the caption must also mention a cat)
        best_mask = None
        best_score = -1
        image_area = original_size[0] * original_size[1]
        
        cat_related_words = ['cat', 'kitten', 'feline', 'tabby', 'kitty']
        caption_contains_cat = any(word in caption_text.lower() for word in cat_related_words)
        
        # Flatten SAM's multimask proposals into individual 2-D masks
        for mask in masks[0].reshape(-1, *masks[0].shape[-2:]):
            mask_binary = mask.numpy() > 0.5
            if caption_contains_cat and is_cat_like(mask_binary, image_area):
                mask_area = mask_binary.sum()
                if mask_area > best_score:
                    best_mask = mask_binary
                    best_score = mask_area
        
        if best_mask is None:
            return input_image, f"Could not find a suitable '{object_name}' in the image."
        
        combined_mask = process_mask(best_mask, original_size)
        
        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
        
        return result_image, f"Segmented '{object_name}' in the image."
    
    except Exception as e:
        return None, f"An error occurred: {str(e)}"

# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)")
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Segment Anything Model (SAM) with Object Specification",
    description="Upload an image and specify an object to segment."
)

# Launch the interface
iface.launch()
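
# Note: when running on a remote machine, Gradio can expose a temporary public
# URL via iface.launch(share=True); whether to enable that is a deployment choice.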