import gradio as gr
import torch
import cv2
import numpy as np
from transformers import SamModel, SamProcessor, BlipProcessor, BlipForConditionalGeneration
from PIL import Image

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SAM model and processor
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Load BLIP model and processor for image-to-text
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
def process_mask(mask, target_size):
    """Collapse extra mask dimensions, binarize, and resize to the original image size."""
    if mask.ndim > 2:
        mask = mask.squeeze()
    if mask.ndim > 2:
        mask = mask[0]
    mask = (mask > 0.5).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask)
    mask_image = mask_image.resize(target_size, Image.NEAREST)
    return np.array(mask_image) > 0
def segment_image(input_image, object_name):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."

        input_image = Image.fromarray(input_image).convert("RGB")
        original_size = input_image.size  # (width, height)
        if not original_size or 0 in original_size:
            return None, "Invalid image size. Please upload a different image."

        # Generate an image caption with BLIP
        blip_inputs = blip_processor(input_image, return_tensors="pt").to(device)
        caption_ids = blip_model.generate(**blip_inputs)
        caption_text = blip_processor.decode(caption_ids[0], skip_special_tokens=True)

        # Process the image with SAM
        sam_inputs = sam_processor(input_image, return_tensors="pt").to(device)

        # Generate masks
        with torch.no_grad():
            sam_outputs = sam_model(**sam_inputs)

        # Post-process masks back to the original resolution
        masks = sam_processor.image_processor.post_process_masks(
            sam_outputs.pred_masks.cpu(),
            sam_inputs["original_sizes"].cpu(),
            sam_inputs["reshaped_input_sizes"].cpu()
        )

        # Pick the largest mask, provided the requested object appears in the caption
        best_mask = None
        best_score = -1
        for mask in masks[0]:
            mask_binary = mask.numpy() > 0.5
            mask_area = mask_binary.sum()
            if object_name.lower() in caption_text.lower() and mask_area > best_score:
                best_mask = mask_binary
                best_score = mask_area

        if best_mask is None:
            return np.array(input_image), f"Could not find '{object_name}' in the image."

        combined_mask = process_mask(best_mask, original_size)

        # Overlay the mask on the original image
        result_image = np.array(input_image)
        mask_rgb = np.zeros_like(result_image)
        mask_rgb[combined_mask] = [255, 0, 0]  # Red color for the mask
        result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)

        return result_image, f"Segmented '{object_name}' in the image."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)")
    ],
    outputs=[
        gr.Image(type="numpy", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="Segment Anything Model (SAM) with Object Specification",
    description="Upload an image and specify an object to segment."
)
# Launch the interface when the script is run directly
if __name__ == "__main__":
    iface.launch()
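
# Optional: a minimal smoke test for segment_image that bypasses the Gradio UI.
# This is a hedged sketch, not part of the Space: "example.jpg" is a hypothetical
# local image path, and it assumes this file is saved as app.py so the function
# can be imported (the __main__ guard above keeps launch() from running on import).
#
#     from app import segment_image
#     import numpy as np
#     from PIL import Image
#
#     image = np.array(Image.open("example.jpg").convert("RGB"))
#     result, status = segment_image(image, "dog")
#     print(status)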