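# Gradio demo that approximates "Segment Anything"-style interaction using CLIPSeg.
# "Everything" segments the whole image with a generic "object" text prompt;
# "Box" segments only the region inside a user-specified bounding box.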
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
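# The model runs on CPU by default; inference can be moved to a GPU with
# model.to("cuda") if one is available.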

def segment_everything(image):
    # Check if image is a list and extract the actual image data
    if isinstance(image, list):
        image = image[0]
    
    # Convert numpy array to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
    inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    # CLIPSeg predicts a fixed-size low-resolution mask (352x352); resize it
    # back to the input image size before returning
    segmentation = (preds.numpy() * 255).astype(np.uint8)
    return Image.fromarray(segmentation).resize(image.size, Image.BILINEAR)

def segment_box(image, x1, y1, x2, y2):
    # Check if image is a list and extract the actual image data
    if isinstance(image, list):
        image = image[0]
    
    # Convert PIL Image to numpy array if necessary
    if isinstance(image, Image.Image):
        image = np.array(image)
    
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    cropped_image = image[y1:y2, x1:x2]
    inputs = processor(text=["object"], images=[Image.fromarray(cropped_image)], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    # CLIPSeg predicts a fixed-size low-resolution mask (352x352), so resize it
    # to the box dimensions before pasting it into a full-size empty mask
    mask = (preds.numpy() * 255).astype(np.uint8)
    mask = np.array(Image.fromarray(mask).resize((x2 - x1, y2 - y1), Image.BILINEAR))
    segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    segmentation[y1:y2, x1:x2] = mask
    return Image.fromarray(segmentation)

def update_image(image, segmentation):
    if segmentation is None:
        return image
    
    # Check if image is a list and extract the actual image data
    if isinstance(image, list):
        image = image[0]
    
    # Ensure image is in the correct format (PIL Image)
    if isinstance(image, np.ndarray):
        image_pil = Image.fromarray(image)
    else:
        image_pil = image
    
    # Convert segmentation to RGBA
    seg_pil = Image.fromarray(segmentation).convert('RGBA')
    
    # Resize segmentation to match input image if necessary
    if image_pil.size != seg_pil.size:
        seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
    
    # Blend images
    blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
    
    return np.array(blended)

with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything-like Demo")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                x1_input = gr.Number(label="X1")
                y1_input = gr.Number(label="Y1")
                x2_input = gr.Number(label="X2")
                y2_input = gr.Number(label="Y2")
            with gr.Row():
                everything_btn = gr.Button("Everything")
                box_btn = gr.Button("Box")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Segmentation Result")

    # After each segmentation run, blend the predicted mask over the input
    # image so the result is shown as an overlay rather than a bare mask.
    everything_btn.click(
        fn=segment_everything,
        inputs=[input_image],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )
    box_btn.click(
        fn=segment_box,
        inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )

demo.launch()