# CLIPSeg2 / app.py
# Last commit: 4fb1cd4 "Update app.py" by sigyllly (verified), 3.46 kB
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
import threading
# The processor (text tokenizer + image preprocessing) and the segmentation
# model are loaded once, at import time, from the same pretrained checkpoint,
# and shared by every request handled below.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    # Add your article and description here
    gr.Markdown("Your article goes here")
    gr.Markdown("Your description goes here")
    with gr.Row():
        # Left column: the input image plus the prompts and threshold that
        # drive the segmentation.
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            # A Button's caption is its first positional argument (``value``);
            # ``label`` is not a Button parameter (ignored with a warning or
            # rejected outright, depending on the Gradio version).
            btn_process = gr.Button("Process")
        # Right column: the cut-out image and the binary mask that produced it.
        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")
def process_image(image, prompt):
    """Return a min-max normalised CLIPSeg relevance map for one text prompt.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region of interest.

    Returns:
        A float64 numpy array of shape ``(image.height, image.width)`` with
        values in ``[0, 1]``. If the model's (quantised) output is constant,
        an all-zero array is returned instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # For a single prompt the logits may come back as (H, W) or (1, H, W)
    # depending on the transformers version; squeeze so PIL always receives a
    # 2-D array (a 3-D array would be misread as a 1-pixel-high image).
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
    # Quantise to 8 bits and resize to the input's size directly in grayscale;
    # an RGB round-trip would only replicate this same channel.
    mask = Image.fromarray(np.uint8(probs * 255)).resize(image.size)
    mask = np.asarray(mask, dtype=np.float64)
    mask_min = mask.min()
    mask_max = mask.max()
    if mask_max == mask_min:
        # A constant map has no contrast to stretch; avoid 0/0 -> NaN.
        return np.zeros_like(mask)
    return (mask - mask_min) / (mask_max - mask_min)
def get_masks(prompts, img, threshold):
    """Return one boolean mask per non-empty prompt in a comma-separated string.

    Args:
        prompts: Comma-separated text prompts, e.g. ``"dog, ball"``. ``None``
            is treated like an empty string.
        img: PIL image to segment.
        threshold: Cut-off applied to each normalised relevance map.

    Returns:
        A non-empty list of ``(img.height, img.width)`` boolean arrays. When
        ``prompts`` contains no usable text, the list holds a single all-False
        mask ("nothing matched"), so callers can always ``np.stack`` it.
    """
    # Strip surrounding whitespace ("dog, ball" -> "dog", "ball") and drop
    # blanks so an empty textbox or a trailing comma is not sent to the model
    # as an empty prompt (which would yield an arbitrary mask).
    cleaned = [p.strip() for p in (prompts or "").split(",")]
    cleaned = [p for p in cleaned if p]
    if not cleaned:
        return [np.zeros((img.height, img.width), dtype=bool)]
    return [process_image(img, p) > threshold for p in cleaned]
def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the regions matching the positive prompts, minus the negative ones.

    Args:
        pos_prompts: Comma-separated prompts describing what to keep.
        neg_prompts: Comma-separated prompts describing what to remove.
        img: PIL image from the input component.
        threshold: Cut-off (the UI slider value) applied to every prompt's
            normalised relevance map.

    Returns:
        ``(output_image, final_mask)``: an RGBA image that is fully transparent
        outside the mask, and the mask itself as an 8-bit grayscale image
        (255 = kept, 0 = removed).

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # The image component can deliver None when nothing was uploaded;
        # show a readable message instead of a processor traceback.
        raise gr.Error("Please upload an image first.")
    # Honour the caller-supplied threshold rather than a fixed constant.
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)
    # A pixel is kept if any positive prompt selects it and no negative one does.
    pos_mask = np.any(np.stack(positive_masks), axis=0)
    neg_mask = np.any(np.stack(negative_masks), axis=0)
    final_mask = pos_mask & ~neg_mask
    # 2-D uint8 arrays map to mode "L"; the explicit mode argument is deprecated.
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255)
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask
# Event listeners must be registered inside a Blocks context, so re-enter
# ``demo`` to wire the Process button to the segmentation handler.
with demo:
    btn_process.click(
        fn=extract_image,
        inputs=[positive_prompts, negative_prompts, input_image, input_slider_T],
        outputs=[output_image, output_mask],
    )
def _extract_image_only(pos_prompts, neg_prompts, img, threshold):
    """Run ``extract_image`` and return only the cut-out image.

    The API interface below has a single output component, so the mask half
    of ``extract_image``'s ``(image, mask)`` result is dropped here.
    """
    result_image, _mask = extract_image(pos_prompts, neg_prompts, img, threshold)
    return result_image


# Standalone API-style interface over the same segmentation pipeline.
# Everything is passed by keyword: Interface's 4th and 5th positional
# parameters are ``examples`` and ``cache_examples``, not component specs.
iface = gr.Interface(
    fn=_extract_image_only,
    inputs=[
        gr.Textbox(label="Positive prompts"),
        gr.Textbox(label="Negative prompts"),
        gr.Image(type="pil"),
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
    ],
    outputs=[gr.Image(label="Result")],  # Only return the final image
    title="CLIPSeg API",
)
# Launch both UI and API. ``launch()`` blocks the calling thread by default,
# so the first app is started without the thread lock; otherwise the second
# launch would never run. The second (blocking) call keeps the process alive.
demo.launch(prevent_thread_lock=True)
iface.launch(share=True)