import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import traceback

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
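
# Optional (not in the original Space): run inference on a GPU when available.
# The preprocessed inputs would then also need .to(device), and preds would
# need .cpu() before calling .numpy():
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)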

def segment_image(input_image, text_prompt):
    try:
        # Ensure input_image is an RGB PIL Image
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        input_image = input_image.convert("RGB")

        # Downscale large images in place to keep inference fast
        max_size = 1024
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size))

        # Preprocess the image and text prompt
        inputs = processor(text=[text_prompt], images=[input_image], padding="max_length", return_tensors="pt")

        # Perform segmentation
        with torch.no_grad():
            outputs = model(**inputs)

        # Apply a sigmoid to the logits to get a per-pixel probability map
        preds = outputs.logits.squeeze().sigmoid()

        # Convert the prediction to a numpy array and scale to 0-255
        segmentation = (preds.numpy() * 255).astype(np.uint8)

        # Resize the segmentation map to match the input image size
        segmentation = Image.fromarray(segmentation).resize(input_image.size)
        segmentation = np.array(segmentation)

        # Create a colored heatmap: red where the prompt matches, blue elsewhere
        heatmap = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
        heatmap[:, :, 0] = segmentation        # Red channel
        heatmap[:, :, 2] = 255 - segmentation  # Blue channel

        # Blend the heatmap with the original image
        original_image = np.array(input_image)
        blended = (0.7 * original_image + 0.3 * heatmap).astype(np.uint8)

        return Image.fromarray(blended), ""
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n\nStacktrace:\n{traceback.format_exc()}"
        # Red image to signal an error, plus the message for the error textbox
        return Image.new('RGB', (400, 200), color=(255, 0, 0)), error_msg
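
# Quick local sanity check (hypothetical file names, not part of the Space):
# result, err = segment_image(Image.open("example.jpg"), "a dog")
# result.save("segmentation.png")
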
# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Text Prompt", placeholder="Enter a description of what to segment...")
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Result"),
        gr.Textbox(label="Error Message")
    ],
    title="CLIPSeg Image Segmentation",
    description="Upload an image and provide a text prompt to segment objects."
)
# Launch the interface
iface.launch()
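
# Note: iface.launch(share=True) would additionally create a public Gradio
# link, which can help when running outside Hugging Face Spaces.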