File size: 3,013 Bytes
5c75869
 
 
 
 
3e99e39
5c75869
 
 
 
3e99e39
5c75869
 
1cce0ac
 
 
 
5c75869
 
 
 
48a7936
 
 
 
 
 
 
5c75869
 
 
 
 
 
48a7936
5c75869
 
1cce0ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c75869
48a7936
5c75869
48a7936
 
5c75869
 
 
48a7936
5c75869
b09d090
1cce0ac
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
import threading

# The processor (tokenizer + image preprocessing) and the segmentation model
# are loaded from the same CLIPSeg checkpoint once, at import time, and are
# reused by every call to process_image below.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")

    # Add your article and description here
    gr.Markdown("Your article goes here")
    gr.Markdown("Your description goes here")

    with gr.Row():
        # Left column: the image to segment, the prompts and the threshold.
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )

            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            # gr.Button takes its caption as the positional ``value`` argument;
            # it has no ``label`` parameter, so ``label="Process"`` was either
            # rejected or ignored (leaving the default caption), depending on
            # the Gradio version.
            btn_process = gr.Button("Process")

        # Right column: the cut-out result and the binary mask that produced it.
        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")

    def process_image(image, prompt):
        """Return CLIPSeg's heat map for ``prompt`` on ``image``, rescaled to [0, 1].

        Args:
            image: PIL image to segment.
            prompt: A single text prompt.

        Returns:
            A float array of shape ``(image.height, image.width)`` whose values
            are min-max normalised to [0, 1]. If the heat map is constant
            (no contrast at all), an all-zero array is returned instead of
            dividing by zero.
        """
        inputs = processor(
            text=prompt, images=image, padding="max_length", return_tensors="pt"
        )

        with torch.no_grad():
            logits = model(**inputs).logits

        # One prompt and one image go in, so there is a single heat map; keep
        # only the last two axes so a leading batch dimension of size 1 (if
        # present) does not reach Image.fromarray as a 3-D array.
        probs = torch.sigmoid(logits).reshape(logits.shape[-2:])
        heat = np.uint8(probs.cpu().numpy() * 255)

        # A 2-D uint8 array becomes a greyscale ("L") image; upsample it to
        # the input image's size.
        resized = Image.fromarray(heat).resize(image.size)
        mask = np.asarray(resized, dtype=np.float64)

        lo, hi = mask.min(), mask.max()
        if hi == lo:
            # Constant map: normalisation is undefined (0 / 0 would give NaNs
            # and a RuntimeWarning). Report "nothing matched" instead.
            return np.zeros_like(mask)
        return (mask - lo) / (hi - lo)

    def get_masks(prompts, img, threshold):
        """Segment ``img`` once per prompt and binarise each heat map.

        Args:
            prompts: Comma-separated text prompts, e.g. ``"cat, red car"``.
                Surrounding whitespace is stripped and blank entries are
                skipped.
            img: PIL image to segment.
            threshold: Cut-off applied to each normalised heat map returned by
                ``process_image``; pixels strictly above it become ``True``.

        Returns:
            A list with one boolean array per non-blank prompt, each shaped
            like ``process_image``'s output. Empty when no prompt is given.
        """
        # "".split(",") is [""]; without this filter an empty textbox would
        # still run the model on an empty prompt, and because process_image
        # rescales every heat map to span [0, 1] that always selects pixels.
        cleaned = (p.strip() for p in (prompts or "").split(","))
        return [process_image(img, p) > threshold for p in cleaned if p]

    def _union(masks, height, width):
        """Return the pixel-wise OR of ``masks``, or an all-False mask if empty."""
        if not masks:
            return np.zeros((height, width), dtype=bool)
        return np.any(np.stack(masks), axis=0)

    def extract_image(pos_prompts, neg_prompts, img, threshold):
        """Cut out the regions matching ``pos_prompts`` but not ``neg_prompts``.

        Args:
            pos_prompts: Comma-separated prompts describing what to keep.
            neg_prompts: Comma-separated prompts describing what to remove;
                may be empty.
            img: Uploaded PIL image, or ``None`` if nothing was uploaded.
            threshold: Value of the "Threshold" slider, applied to both the
                positive and the negative heat maps.

        Returns:
            ``(cutout, mask)``: an RGBA image the size of ``img`` that is
            transparent outside the selected region, and the greyscale
            selection mask (255 = kept, 0 = dropped).

        Raises:
            gr.Error: If no image was uploaded.
        """
        if img is None:
            raise gr.Error("Please upload an image first.")

        # Honour the user's slider value rather than a fixed cut-off.
        positive_masks = get_masks(pos_prompts, img, threshold)
        negative_masks = get_masks(neg_prompts, img, threshold)

        width, height = img.size
        keep = _union(positive_masks, height, width) & ~_union(
            negative_masks, height, width
        )

        # A 2-D uint8 array is converted to a greyscale ("L") PIL image.
        final_mask = Image.fromarray(keep.astype(np.uint8) * 255)
        cutout = Image.new("RGBA", img.size, (0, 0, 0, 0))
        cutout.paste(img, mask=final_mask)
        return cutout, final_mask

    btn_process.click(
        extract_image,
        inputs=[
            positive_prompts,
            negative_prompts,
            input_image,
            input_slider_T,
        ],
        outputs=[output_image, output_mask],
    )

# Launch Gradio API
demo.launch(share=True)