from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
from flask import Flask, request, jsonify, send_file
from io import BytesIO
import threading
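
# Load the CLIPSeg processor and model once at startup. The model predicts a
# low-resolution segmentation logit map for each (image, text prompt) pair.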
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
app = Flask(__name__)
# Define the UI text as global variables (the `description` blurb below is
# placeholder copy summarizing the app; adjust the wording as needed)
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
description = "Segment an image with natural-language prompts: list what to keep as positive prompts and what to exclude as negative prompts, then tune the threshold."
def process_image(image, prompt):
    """Run CLIPSeg on one (image, prompt) pair and return a soft mask in [0, 1]."""
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits
    pred = torch.sigmoid(preds)
    mat = pred.cpu().numpy()
    # Convert the low-resolution prediction to a grayscale image and resize it
    # to the input image's dimensions.
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.convert("RGB")
    mask = mask.resize(image.size)
    mask = np.array(mask)[:, :, 0]
    # Normalize to [0, 1]; the small epsilon guards against division by zero
    # when the predicted mask is constant.
    mask_min = mask.min()
    mask_max = mask.max()
    mask = (mask - mask_min) / (mask_max - mask_min + 1e-8)
    return mask
def get_masks(prompts, img, threshold):
    """Binarize one soft mask per comma-separated prompt."""
    prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    masks = []
    for prompt in prompts:
        mask = process_image(img, prompt)
        masks.append(mask > threshold)
    return masks
def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Keep pixels matched by any positive prompt and not by any negative prompt."""
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)
    pos_mask = np.any(np.stack(positive_masks), axis=0)
    # An empty negative-prompt box subtracts nothing from the positive mask.
    if negative_masks:
        neg_mask = np.any(np.stack(negative_masks), axis=0)
        final_mask = pos_mask & ~neg_mask
    else:
        final_mask = pos_mask
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
    # Composite the original image onto a transparent canvas through the mask.
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask
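
# A minimal REST endpoint sketch so the Flask imports above are exercised.
# The /segment route and its field names are illustrative assumptions, not a
# fixed API: POST a file under "image" plus form fields "positive", "negative",
# and an optional "threshold", and receive the extracted RGBA image as a PNG.
@app.route("/segment", methods=["POST"])
def segment():
    if "image" not in request.files:
        return jsonify({"error": "missing 'image' file"}), 400
    pos = request.form.get("positive", "")
    neg = request.form.get("negative", "")
    if not pos.strip():
        return jsonify({"error": "missing 'positive' prompts"}), 400
    img = Image.open(request.files["image"].stream).convert("RGB")
    threshold = float(request.form.get("threshold", 0.4))
    output_image, _ = extract_image(pos, neg, img, threshold)
    buf = BytesIO()
    output_image.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
# Example call (Flask port as configured below):
#   curl -F image=@photo.jpg -F positive=cat -F negative=background \
#        http://127.0.0.1:5000/segment -o cutout.png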
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            btn_process = gr.Button("Process")
        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")
    btn_process.click(
        extract_image,
        inputs=[
            positive_prompts,
            negative_prompts,
            input_image,
            input_slider_T,
        ],
        outputs=[output_image, output_mask],
    )
def run_demo():
    # Gradio serves on its default port (7860).
    demo.launch()
def run_flask():
    # Bind Flask to a different port than Gradio; both servers cannot
    # listen on 7860 at once.
    app.run(host="127.0.0.1", port=5000)
if __name__ == "__main__":
    # Run the Gradio UI and the Flask API in separate threads
    gr_thread = threading.Thread(target=run_demo)
    flask_thread = threading.Thread(target=run_flask)
    gr_thread.start()
    flask_thread.start()
    gr_thread.join()
    flask_thread.join()