# SegmentVision / app.py
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import traceback
# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
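# Optional sketch (an assumption, not part of the original app): run inference
# on a GPU when one is available. The tensors returned by the processor would
# then also need to be moved with .to(device) before calling the model.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)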
def segment_image(input_image, text_prompt):
    try:
        # Ensure input_image is a PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        # Convert to RGB so the heatmap blend below always sees 3 channels,
        # even for grayscale or RGBA uploads
        input_image = input_image.convert("RGB")
        # Downscale overly large images in place, preserving aspect ratio
        max_size = 1024
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size))
        # Preprocess the image and text prompt
        inputs = processor(text=[text_prompt], images=[input_image], padding="max_length", return_tensors="pt")
        # Perform segmentation
        with torch.no_grad():
            outputs = model(**inputs)
        # Map the raw logits to per-pixel probabilities in [0, 1]
        preds = outputs.logits.squeeze().sigmoid()
        # Convert the prediction to a numpy array and scale to 0-255
        segmentation = (preds.numpy() * 255).astype(np.uint8)
        # Resize the segmentation map back to the input image size
        segmentation = Image.fromarray(segmentation).resize(input_image.size)
        segmentation = np.array(segmentation)
        # Create a colored heatmap: high scores in red, low scores in blue
        heatmap = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
        heatmap[:, :, 0] = segmentation        # Red channel
        heatmap[:, :, 2] = 255 - segmentation  # Blue channel
        # Blend the heatmap with the original image
        original_image = np.array(input_image)
        blended = (0.7 * original_image + 0.3 * heatmap).astype(np.uint8)
        # Return one value per Gradio output: the overlay and an empty error message
        return Image.fromarray(blended), ""
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n\nStacktrace:\n{traceback.format_exc()}"
        # Red image signals failure; the message goes to the error textbox
        return Image.new('RGB', (400, 200), color=(255, 0, 0)), error_msg
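# Quick local sanity check, bypassing the UI (assumption: "example.jpg" is a
# hypothetical test image on disk; it is not shipped with this Space):
# result, err = segment_image(Image.open("example.jpg"), "a cat")
# result.save("segmentation_overlay.png")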
# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Text Prompt", placeholder="Enter a description of what to segment...")
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Result"),
        # Kept visible so the error message returned by segment_image is shown
        gr.Textbox(label="Error Message")
    ],
    title="CLIPSeg Image Segmentation",
    description="Upload an image and provide a text prompt to segment the matching objects."
)
# Launch the interface
iface.launch()
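# For a temporary public URL when running outside Hugging Face Spaces,
# Gradio supports: iface.launch(share=True)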