File size: 3,854 Bytes
5c75869
 
 
 
 
3e99e39
 
 
5c75869
 
 
 
3e99e39
5c75869
b09d090
 
6adb478
 
b09d090
 
48a7936
5c75869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48a7936
5c75869
3e99e39
48a7936
 
 
 
3e99e39
48a7936
 
 
3e99e39
48a7936
 
5c75869
48a7936
 
 
 
 
 
 
 
5c75869
3e99e39
5c75869
 
 
 
 
 
 
 
48a7936
 
 
 
 
 
 
5c75869
 
 
 
 
 
48a7936
5c75869
 
 
48a7936
5c75869
48a7936
 
5c75869
 
 
48a7936
5c75869
b09d090
3e99e39
 
 
 
 
 
 
 
 
 
5c75869
3e99e39
 
5c75869
3e99e39
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
from flask import Flask, request, jsonify, send_file
from io import BytesIO
import threading

# Hugging Face checkpoint shared by the processor and the segmentation model.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

app = Flask(__name__)


# Module-level text shown by the Gradio UI (title, description, links).
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = (
    "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image "
    "segmentation. To use it, simply upload an image and add a text to mask "
    "(identify in the image), or use one of the examples below and click "
    "'submit'. Results will show up in a few seconds."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2112.10003'>"
    "CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | "
    "<a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>"
    "HuggingFace docs</a></p>"
)

def process_image(image, prompt):
    """Return CLIPSeg's relevance map for *prompt* over *image*.

    Args:
        image: PIL image to segment; ``image.size`` is used to resize the
            model output back to the input resolution.
        prompt: text describing the region to locate.

    Returns:
        A float64 numpy array of shape ``(height, width)``, min-max
        normalized to ``[0, 1]``. If the prediction is constant (max == min)
        an all-zeros array is returned instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits

    # Sigmoid maps logits to per-pixel scores in (0, 1). Squeeze away any
    # singleton batch axis so PIL receives a 2-D array for mode "L".
    probs = torch.sigmoid(preds).cpu().numpy().squeeze()
    low_res = Image.fromarray(np.uint8(probs * 255), "L")
    # PIL's resize takes (width, height), which is exactly image.size; the
    # resulting numpy array is (height, width).
    mask = np.asarray(low_res.resize(image.size), dtype=np.float64)

    mask_min = mask.min()
    mask_range = mask.max() - mask_min
    if mask_range == 0:
        # A flat prediction carries no signal; avoid 0/0 -> NaN + warning.
        return np.zeros_like(mask)
    return (mask - mask_min) / mask_range

def get_masks(prompts, img, threshold):
    """Build one boolean mask per comma-separated prompt.

    Args:
        prompts: comma-separated prompt string, e.g. ``"cat, dog"``.
            Whitespace around each entry is stripped and empty entries
            (an empty textbox, a trailing comma) are skipped. ``None`` is
            treated like an empty string.
        img: PIL image passed to ``process_image``.
        threshold: pixels whose normalized score is strictly greater than
            this value are selected.

    Returns:
        A list of boolean arrays, one per non-empty prompt; ``[]`` when
        there are none.
    """
    cleaned = [p.strip() for p in (prompts or "").split(",")]
    # Skipping blanks matters: an empty prompt would still produce a
    # normalized heat-map and select an arbitrary part of the image.
    return [process_image(img, p) > threshold for p in cleaned if p]

def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the regions of *img* matched by the positive prompts.

    A pixel is kept when any positive prompt selects it and no negative
    prompt does. Both prompt lists are thresholded with *threshold*.

    Args:
        pos_prompts: comma-separated prompts describing what to keep.
        neg_prompts: comma-separated prompts describing what to drop; may
            be empty.
        img: PIL input image.
        threshold: normalized-score cut-off forwarded to ``get_masks``.

    Returns:
        ``(output_image, final_mask)``: an RGBA image that is fully
        transparent outside the selection, and the selection as an ``"L"``
        image (255 = kept). With no positive prompts nothing is selected.
    """
    # numpy masks are (height, width); PIL's img.size is (width, height).
    shape = (img.size[1], img.size[0])

    def _union(masks):
        # np.stack raises on an empty list, so fall back to "nothing".
        if not masks:
            return np.zeros(shape, dtype=bool)
        return np.any(np.stack(masks), axis=0)

    pos_mask = _union(get_masks(pos_prompts, img, threshold))
    neg_mask = _union(get_masks(neg_prompts, img, threshold))
    final_mask = pos_mask & ~neg_mask

    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    # Paste only where the mask is 255; elsewhere stays transparent.
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask

# Gradio UI: image, prompts and threshold on the left; results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            # type="pil" so extract_image receives a PIL image.
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )

            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            # A Button's caption is its first positional argument (``value``);
            # ``label`` is not a Button parameter and would not set the text.
            btn_process = gr.Button("Process")

        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")

    # Argument order must match extract_image's signature.
    btn_process.click(
        extract_image,
        inputs=[
            positive_prompts,
            negative_prompts,
            input_image,
            input_slider_T,
        ],
        outputs=[output_image, output_mask],
    )

def run_demo():
    """Launch the module-level Gradio ``demo`` app."""
    gradio_app = demo
    gradio_app.launch()

def run_flask(host='127.0.0.1', port=5000):
    """Serve the module-level Flask ``app``.

    Args:
        host: interface to bind.
        port: TCP port to bind. It must differ from the port Gradio's
            ``demo.launch()`` uses (7860 by default). Both servers run in
            this process, so sharing a port makes whichever binds second
            fail with "address already in use". The old value of 7860
            collided with Gradio.
    """
    app.run(host=host, port=port)

if __name__ == '__main__':
    # One thread per server (Gradio UI first, then Flask); the process stays
    # alive until both threads have finished.
    workers = [
        threading.Thread(target=run_demo),
        threading.Thread(target=run_flask),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()