import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Ensure that the model uses GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def segment_everything(image):
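    # Segment the entire image with a generic "object" prompt and return the predicted mask.
    # Gradio may pass the image as a list or a NumPy array; normalize it to a PIL image first.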
    if isinstance(image, list):
        image = image[0]
    
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
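    # Preprocess the image together with the generic "object" text prompt for CLIPSeg.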
    inputs = processor(text=["object"], images=image, padding="max_length", return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
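    # Convert logits to a [0, 1] probability mask; CLIPSeg predicts at a fixed resolution, not the input size.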
    preds = outputs.logits.squeeze().sigmoid().cpu()
    segmentation = (preds.numpy() * 255).astype(np.uint8)
    return Image.fromarray(segmentation)

def segment_box(image, x1, y1, x2, y2):
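    # Segment only the region inside the (x1, y1, x2, y2) box and return a full-size mask.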
    if isinstance(image, list):
        image = image[0]
    
    if isinstance(image, Image.Image):
        image = np.array(image)
    
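    # Coordinates arrive as floats from gr.Number; cast them to ints and crop the region of interest.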
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    cropped_image = image[y1:y2, x1:x2]
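    # Run CLIPSeg on just the cropped region with the same generic "object" prompt.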
    inputs = processor(text=["object"], images=Image.fromarray(cropped_image), padding="max_length", return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid().cpu()
    # CLIPSeg predicts at a fixed resolution, so resize the mask back to the box size before pasting it in.
    mask = Image.fromarray((preds.numpy() * 255).astype(np.uint8)).resize((x2 - x1, y2 - y1), Image.BILINEAR)
    segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    segmentation[y1:y2, x1:x2] = np.array(mask)
    return Image.fromarray(segmentation)

def update_image(image, segmentation):
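    # Overlay the segmentation mask on the original image so the result is easier to inspect.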
    if segmentation is None:
        return image
    
    if isinstance(image, list):
        image = image[0]
    
    if isinstance(image, np.ndarray):
        image_pil = Image.fromarray(image)
    else:
        image_pil = image
    
    seg_pil = Image.fromarray(segmentation).convert('RGBA')
    
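    # The mask may be at a different resolution than the input; resize it to match before blending.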
    if image_pil.size != seg_pil.size:
        seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
    
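    # Alpha-blend the mask over the original image at 50% opacity.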
    blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
    
    return np.array(blended)

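# Build the Gradio UI: an input image, box coordinate fields, and two segmentation modes.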
with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything-like Demo")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                x1_input = gr.Number(label="X1")
                y1_input = gr.Number(label="Y1")
                x2_input = gr.Number(label="X2")
                y2_input = gr.Number(label="Y2")
            with gr.Row():
                everything_btn = gr.Button("Everything")
                box_btn = gr.Button("Box")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Segmentation Result")

    # "Everything" mode: run CLIPSeg over the whole image, then overlay the mask on the input.
    everything_btn.click(
        fn=segment_everything,
        inputs=[input_image],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )
    # "Box" mode: segment only inside the given coordinates, then overlay the resulting mask.
    box_btn.click(
        fn=segment_box,
        inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )

demo.launch()