import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from scipy.ndimage import label
# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load SAM model and processor
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Load BLIP model and processor for image-to-text
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
def process_mask(mask, target_size):
    # Collapse any leading singleton/batch dimensions down to a single 2D mask
    if mask.ndim > 2:
        mask = mask.squeeze()
    if mask.ndim > 2:
        mask = mask[0]
    # Binarize, convert to an 8-bit image, and resize to the original image size
    mask = (mask > 0.5).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask)
    mask_image = mask_image.resize(target_size, Image.NEAREST)
    return np.array(mask_image) > 0
def is_cat_like(mask, image_area):
    # Keep only the largest connected component of the mask
    labeled, num_features = label(mask)
    if num_features == 0:
        return False
    largest_component = (labeled == (np.bincount(labeled.flatten())[1:].argmax() + 1))
    area = largest_component.sum()
    # Check if the area is reasonable for a cat (between 5% and 30% of the image)
    if not (0.05 * image_area < area < 0.3 * image_area):
        return False
    # Check if the shape is roughly elliptical using the component's bounding box
    ys, xs = np.nonzero(largest_component)
    height = ys.max() - ys.min() + 1
    width = xs.max() - xs.min() + 1
    major_axis = max(height, width)
    minor_axis = max(min(height, width), 1)
    aspect_ratio = major_axis / minor_axis
    return 1.5 < aspect_ratio < 3  # Most cats fall in this aspect-ratio range
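# Illustrative sanity check for the heuristic above (not part of the app; the synthetic
# blob and the 100x100 image size are assumptions for demonstration only). A 30x60
# rectangle covers 18% of the image with a bounding-box aspect ratio of 2.0, so it
# passes both the area and shape tests:
#
#     demo_mask = np.zeros((100, 100), dtype=bool)
#     demo_mask[35:65, 20:80] = True
#     assert is_cat_like(demo_mask, image_area=100 * 100)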
def segment_image(input_image, object_name):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        input_image = Image.fromarray(input_image).convert("RGB")
        original_size = input_image.size
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."
        # Generate a detailed image caption with BLIP
        blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
        with torch.no_grad():
            caption = blip_model.generate(**blip_inputs, max_length=50)
        caption_text = blip_processor.decode(caption[0], skip_special_tokens=True)
        # Process the image with SAM
        sam_inputs = sam_processor(input_image, return_tensors="pt").to(device)
        # Generate masks
        with torch.no_grad():
            sam_outputs = sam_model(**sam_inputs)
        # Post-process masks back to the original image resolution
        masks = sam_processor.image_processor.post_process_masks(
            sam_outputs.pred_masks.cpu(),
            sam_inputs["original_sizes"].cpu(),
            sam_inputs["reshaped_input_sizes"].cpu()
        )
        # Find the mask that best matches the specified object.
        # Note: the heuristic below is cat-specific; it combines the caption check
        # with the size/shape test in is_cat_like.
        best_mask = None
        best_score = -1
        image_area = original_size[0] * original_size[1]
        cat_related_words = ['cat', 'kitten', 'feline', 'tabby', 'kitty']
        caption_contains_cat = any(word in caption_text.lower() for word in cat_related_words)
        # Flatten the per-image mask tensor into individual 2D candidate masks
        candidate_masks = masks[0].reshape(-1, *masks[0].shape[-2:])
        for mask in candidate_masks:
            mask_binary = mask.numpy() > 0.5
            if is_cat_like(mask_binary, image_area) and caption_contains_cat:
                mask_area = mask_binary.sum()
                if mask_area > best_score:
                    best_mask = mask_binary
                    best_score = mask_area
        if best_mask is None:
            return input_image, f"Could not find a suitable '{object_name}' in the image."
        combined_mask = process_mask(best_mask, original_size)
        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
        return result_image, f"Segmented '{object_name}' in the image."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)")
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Segment Anything Model (SAM) with Object Specification",
    description="Upload an image and specify an object to segment."
)
# Launch the interface
iface.launch()
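# Optional local sanity check (illustrative only; "cat.jpg" is a placeholder path, not a
# file shipped with this Space). Run it instead of iface.launch() to exercise the
# segmentation pipeline once without the Gradio UI:
#
#     test_image = np.array(Image.open("cat.jpg").convert("RGB"))
#     result, status = segment_image(test_image, "cat")
#     print(status)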