import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import traceback

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
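
# Optional sketch (not part of the original app): run inference on GPU when
# available. This assumes a CUDA-capable torch build; the processor outputs
# would also need to be moved to the same device before the forward pass:
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model.to(device)
#     inputs = {k: v.to(device) for k, v in inputs.items()}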

def segment_image(input_image, text_prompt):
    try:
        # Ensure input_image is a PIL Image in RGB mode
        # (uploads may arrive as RGBA or grayscale, which would break the blend below)
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        input_image = input_image.convert("RGB")

        # Downscale in place if the image is too large (thumbnail preserves aspect ratio)
        max_size = 1024
        if max(input_image.size) > max_size:
            input_image.thumbnail((max_size, max_size))

        # Preprocess the image
        inputs = processor(text=[text_prompt], images=[input_image], padding="max_length", return_tensors="pt")

        # Perform segmentation
        with torch.no_grad():
            outputs = model(**inputs)

        # Get the predicted segmentation: CLIPSeg emits a low-resolution logit map,
        # and sigmoid maps the logits to per-pixel probabilities in [0, 1]
        preds = outputs.logits.squeeze().sigmoid()

        # Convert the prediction to a numpy array and scale to 0-255
        segmentation = (preds.numpy() * 255).astype(np.uint8)

        # Resize segmentation to match input image size
        segmentation = Image.fromarray(segmentation).resize(input_image.size)
        segmentation = np.array(segmentation)

        # Create a colored heatmap
        heatmap = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
        heatmap[:, :, 0] = segmentation  # Red channel
        heatmap[:, :, 2] = 255 - segmentation  # Blue channel

        # Blend the heatmap with the original image
        original_image = np.array(input_image)
        blended = (0.7 * original_image + 0.3 * heatmap).astype(np.uint8)

        return Image.fromarray(blended), ""
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n\nStacktrace:\n{traceback.format_exc()}"
        # Return a red placeholder image plus the message, matching the interface's
        # two outputs so the error actually surfaces in the UI
        return Image.new("RGB", (400, 200), color=(255, 0, 0)), error_msg
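
# Note: a hard binary mask could be derived instead of the blended heatmap by
# thresholding the sigmoid output before scaling (the 0.5 threshold here is an
# assumption, not part of the original app):
#
#     mask = (preds.numpy() > 0.5).astype(np.uint8) * 255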

# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Text Prompt", placeholder="Enter a description of what to segment...")
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Result"),
        gr.Textbox(label="Error Message")
    ],
    title="CLIPSeg Image Segmentation",
    description="Upload an image and provide a text prompt to segment objects."
)

# Launch the interface
iface.launch()
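
# Quick headless check of segment_image without the Gradio UI (hypothetical
# example; "cat.jpg" is a placeholder path, not shipped with this app):
#
#     from PIL import Image
#     image, error = segment_image(Image.open("cat.jpg"), "a cat")
#     image.save("cat_mask.png")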