import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import traceback

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
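# Note: CLIPSeg predicts a low-resolution mask (352x352 for this checkpoint),
# which is why the prediction is resized back to the input size below.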

def segment_image(input_image, text_prompt):
    try:
        # Ensure input_image is a PIL Image in RGB mode
        # (the blending step below assumes three channels)
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        input_image = input_image.convert("RGB")

        # Downscale large images in place to keep inference fast
        max_size = 1024
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size))

        # Preprocess the image and text prompt
        inputs = processor(text=[text_prompt], images=[input_image], padding="max_length", return_tensors="pt")
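        # padding="max_length" mirrors the CLIPSeg usage example in the
        # transformers documentation for this checkpoint.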

        # Perform segmentation
        with torch.no_grad():
            outputs = model(**inputs)

        # Get the predicted segmentation as probabilities in [0, 1]
        preds = outputs.logits.squeeze().sigmoid()

        # Convert the prediction to a numpy array and scale to 0-255
        segmentation = (preds.numpy() * 255).astype(np.uint8)

        # Resize the segmentation to match the input image size
        segmentation = Image.fromarray(segmentation).resize(input_image.size)
        segmentation = np.array(segmentation)

        # Create a colored heatmap: red where the prompt matches, blue elsewhere
        heatmap = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
        heatmap[:, :, 0] = segmentation        # Red channel
        heatmap[:, :, 2] = 255 - segmentation  # Blue channel

        # Blend the heatmap with the original image
        original_image = np.array(input_image)
        blended = (0.7 * original_image + 0.3 * heatmap).astype(np.uint8)

        # Return the blended image plus an empty error message
        # (the interface declares two outputs)
        return Image.fromarray(blended), ""
    except Exception as e:
        error_msg = f"An error occurred: {e}\n\nStacktrace:\n{traceback.format_exc()}"
        # Return a red image to indicate failure, along with the error details
        return Image.new("RGB", (400, 200), color=(255, 0, 0)), error_msg

# Create the Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Text Prompt", placeholder="Enter a description of what to segment..."),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Result"),
        # Keep the error box visible so failures surface to the user
        gr.Textbox(label="Error Message"),
    ],
    title="CLIPSeg Image Segmentation",
    description="Upload an image and provide a text prompt to segment matching objects.",
)

# Launch the interface
iface.launch()
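
# Quick smoke test without the UI (a sketch; "example.jpg" is a hypothetical
# local file, and launch() above blocks, so run these lines separately):
#   result, error = segment_image(Image.open("example.jpg"), "a cat")
#   result.save("segmentation.png")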