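"""Gradio demo approximating a "Segment Anything"-style workflow with CLIPSeg.

"Everything" segments the whole image with a generic text prompt; "Box"
restricts segmentation to a user-supplied bounding box.
"""
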
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
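
# Optional (an assumption, not part of the original demo): move the model to a
# GPU when one is available. The processor outputs would then also need
# .to(device) inside each handler before calling the model.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)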

def segment_everything(image):
    # CLIPSeg is text-prompted; "object" acts as a generic prompt to
    # approximate "segment everything".
    inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    # The logits come out at the model's fixed low resolution, so resize the
    # mask back to the input size before returning it.
    mask = Image.fromarray((preds.numpy() * 255).astype(np.uint8))
    height, width = image.shape[:2]
    return mask.resize((width, height), Image.BILINEAR)
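
# A minimal sketch of the text-prompted mode CLIPSeg is actually designed for:
# one mask per prompt. `segment_by_text` and its prompt list are illustrative
# additions and are not wired into the UI below.
def segment_by_text(image, prompts):
    inputs = processor(text=prompts, images=[image] * len(prompts),
                       padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.sigmoid()
    if preds.dim() == 2:  # single prompt: restore the leading batch axis
        preds = preds.unsqueeze(0)
    return [(p.numpy() * 255).astype(np.uint8) for p in preds]
# Example: masks = segment_by_text(image, ["a cat", "the sky"])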

def segment_box(image, x1, y1, x2, y2):
    # Segment only the user-specified box, then paste the mask back into a
    # full-size canvas so it aligns with the original image.
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    cropped_image = image[y1:y2, x1:x2]
    inputs = processor(text=["object"], images=[cropped_image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    # The predicted mask is at the model's fixed resolution, not the box size;
    # resize it before pasting, or the assignment below would fail.
    mask = Image.fromarray((preds.numpy() * 255).astype(np.uint8))
    mask = mask.resize((x2 - x1, y2 - y1), Image.BILINEAR)
    segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    segmentation[y1:y2, x1:x2] = np.array(mask)
    return Image.fromarray(segmentation)

def update_image(image, segmentation):
    if segmentation is None:
        return image
    
    # Gradio supplies images as uint8 RGB arrays in [0, 255], so convert
    # directly; multiplying by 255 here would wrap around and corrupt colors.
    if isinstance(image, np.ndarray):
        image_pil = Image.fromarray(image.astype(np.uint8))
    else:
        image_pil = image
    
    # Convert segmentation to RGBA
    seg_pil = Image.fromarray(segmentation).convert('RGBA')
    
    # Resize segmentation to match input image if necessary
    if image_pil.size != seg_pil.size:
        seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
    
    # Blend images
    blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
    
    return np.array(blended)
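
# A hedged alternative to the 50/50 blend above: tint only the segmented
# pixels by using the mask as a per-pixel alpha channel. `overlay_mask` is a
# sketch and is not used by the UI below.
def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    base = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
    mask_l = Image.fromarray(mask).convert("L").resize(base.size, Image.NEAREST)
    tint = Image.new("RGBA", base.size, color + (0,))
    # Scale the tint's alpha by mask intensity so confident regions show more color.
    tint.putalpha(Image.fromarray((np.array(mask_l) * alpha).astype(np.uint8)))
    return np.array(Image.alpha_composite(base, tint))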

with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything-style Demo (CLIPSeg)")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                x1_input = gr.Number(label="X1")
                y1_input = gr.Number(label="Y1")
                x2_input = gr.Number(label="X2")
                y2_input = gr.Number(label="Y2")
            with gr.Row():
                everything_btn = gr.Button("Everything")
                box_btn = gr.Button("Box")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Segmentation Result")

    # Chain the overlay step after each segmentation. Listening on
    # output_image.change would re-trigger itself every time update_image
    # writes back to output_image, looping indefinitely.
    everything_btn.click(
        fn=segment_everything,
        inputs=[input_image],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )
    box_btn.click(
        fn=segment_box,
        inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )

demo.launch()
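
# To get a shareable public link (e.g., when running in a notebook), Gradio
# supports demo.launch(share=True).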