File size: 3,464 Bytes
5c75869
 
 
 
 
3e99e39
5c75869
 
 
 
4fb1cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b09d090
4fb1cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bde246
4fb1cd4
 
 
 
 
 
6e7b1c7
4fb1cd4
 
 
 
6e7b1c7
 
4fb1cd4
 
 
6bde246
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
import threading

# The preprocessor and the segmentation network are both built from the same
# pretrained checkpoint, once, at import time.
_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CHECKPOINT)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")

    # Add your article and description here
    gr.Markdown("Your article goes here")
    gr.Markdown("Your description goes here")

    with gr.Row():
        # Left column: everything the user supplies.
        with gr.Column():
            # type="pil" makes the callback receive a PIL image.
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            # A button's caption is its first (`value`) argument; passing it as
            # `label=` does not set the caption and is rejected by newer Gradio.
            btn_process = gr.Button("Process")

        # Right column: the cut-out image and the binary mask that produced it.
        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")

    def process_image(image, prompt):
        """Segment ``image`` with CLIPSeg for a single text ``prompt``.

        Args:
            image: PIL image to segment; its ``size`` decides the size of the
                returned mask.
            prompt: Text describing the region of interest.

        Returns:
            A 2-D float NumPy array with the image's height and width, min-max
            normalised to ``[0, 1]``. If the prediction is completely flat
            (no contrast to normalise), an all-zero array is returned.
        """
        inputs = processor(
            text=prompt, images=image, padding="max_length", return_tensors="pt"
        )
        with torch.no_grad():
            logits = model(**inputs).logits

        # NOTE(review): for a single prompt the logits are expected to be
        # (H, W) or (1, H, W) depending on the transformers version - squeeze
        # so that either form becomes a plain 2-D map.
        probs = np.squeeze(torch.sigmoid(logits).cpu().numpy())

        # Go through an 8-bit grayscale image so PIL can resample the map to
        # the input image's resolution.
        mask = Image.fromarray(np.uint8(probs * 255)).resize(image.size)
        mask = np.asarray(mask, dtype=np.float64)

        low = mask.min()
        high = mask.max()
        if high == low:
            # A flat map would make the normalisation divide by zero.
            return np.zeros_like(mask)
        return (mask - low) / (high - low)

    def get_masks(prompts, img, threshold):
        """Return one boolean mask per comma-separated entry of ``prompts``.

        Every entry is segmented with ``process_image``; a pixel is set in the
        resulting mask when its normalised score is strictly above
        ``threshold``.
        """
        return [
            process_image(img, text) > threshold for text in prompts.split(",")
        ]

    def extract_image(pos_prompts, neg_prompts, img, threshold):
        """Cut the regions matching the positive prompts out of ``img``.

        Args:
            pos_prompts: Comma-separated descriptions of what to keep.
            neg_prompts: Comma-separated descriptions of what to remove from
                the selection; may be empty or ``None``.
            img: PIL image to segment.
            threshold: Cut-off applied to every normalised mask; pixels scoring
                strictly above it count as a match.

        Returns:
            A ``(cutout, mask)`` tuple: ``cutout`` is an RGBA image of the same
            size as ``img`` with the selected pixels pasted onto a transparent
            background, ``mask`` is the grayscale selection mask (255 where
            selected, 0 elsewhere).

        Raises:
            gr.Error: If no image or no non-blank positive prompt was given.
        """

        def _clean(text):
            # Drop blank entries so an empty textbox (or a stray comma) is not
            # segmented as the prompt "".
            parts = (part.strip() for part in (text or "").split(","))
            return ",".join(part for part in parts if part)

        if img is None:
            raise gr.Error("Please provide an image to segment.")

        pos_prompts = _clean(pos_prompts)
        neg_prompts = _clean(neg_prompts)
        if not pos_prompts:
            raise gr.Error("Please describe at least one thing to identify.")

        # Use the caller's threshold so the UI slider actually takes effect.
        positive_masks = get_masks(pos_prompts, img, threshold)
        selected = np.any(np.stack(positive_masks), axis=0)

        if neg_prompts:
            negative_masks = get_masks(neg_prompts, img, threshold)
            if negative_masks:
                selected &= ~np.any(np.stack(negative_masks), axis=0)

        final_mask = Image.fromarray(selected.astype(np.uint8) * 255)
        # Start fully transparent and copy only the selected pixels across.
        output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
        output_image.paste(img, mask=final_mask)

        return output_image, final_mask

    # Wire the button: the four inputs map positionally onto extract_image's
    # parameters, and its two return values fill the result and mask panels.
    btn_process.click(
        fn=extract_image,
        inputs=[positive_prompts, negative_prompts, input_image, input_slider_T],
        outputs=[output_image, output_mask],
    )
# API-oriented interface that exposes only the final cut-out image.
def _extract_result_only(pos_prompts, neg_prompts, img, threshold):
    """Run ``extract_image`` and return just the cut-out image (no mask)."""
    return extract_image(pos_prompts, neg_prompts, img, threshold)[0]


iface = gr.Interface(
    fn=_extract_result_only,
    inputs=[
        gr.Textbox(label="Positive prompts"),
        gr.Textbox(label="Negative prompts"),
        gr.Image(type="pil"),
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
    ],
    # A single output component, so the wrapped function must return exactly
    # one value - hence the wrapper above.
    outputs=[gr.Image(label="Result")],
    # Everything is passed by keyword: the 4th and 5th positional parameters
    # of gr.Interface are `examples` and `cache_examples`, and no examples are
    # configured for this interface.
    title="CLIPSeg API",
)

# Launch both UI and API.
# launch() blocks the calling thread by default, so the first app is started
# with prevent_thread_lock=True; otherwise the second launch would not run
# until the first server had stopped. The final launch() blocks as usual and
# keeps the process (and therefore both servers) alive.
demo.launch(prevent_thread_lock=True)
iface.launch(share=True)