CLIPSeg2 / app.py
sigyllly's picture
Update app.py
9f97f60 verified
raw
history blame
2.8 kB
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
import threading
# Hugging Face checkpoint id shared by the processor and the model.
_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Both objects are created once at import time; process_image reads them as
# module-level globals.
processor = CLIPSegProcessor.from_pretrained(_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CHECKPOINT)
def process_image(image, prompt):
    """Return a min-max normalized CLIPSeg response map for one prompt.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region to look for.

    Returns:
        A 2-D float ``numpy`` array with the same height and width as
        ``image`` and values in ``[0, 1]``. When the resized response is
        uniform (no contrast to normalize), an all-zero map is returned
        instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # The logits for a single prompt may or may not carry a leading batch
    # axis of size 1; squeeze so the array handed to Pillow is always 2-D.
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

    # Quantize to 8 bits so Pillow can resample the map to the input size.
    # A 2-D uint8 array is read as a single-channel greyscale image, so no
    # RGB round-trip is needed.
    resized = Image.fromarray(np.uint8(probs * 255)).resize(image.size)
    mask = np.asarray(resized, dtype=np.float64)

    low = mask.min()
    span = mask.max() - low
    if span == 0:
        # Constant map: normalizing would compute 0 / 0 and produce NaNs.
        return np.zeros_like(mask)
    return (mask - low) / span
def get_masks(prompts, img, threshold):
    """Build one boolean mask per comma-separated prompt.

    Args:
        prompts: Comma-separated prompt text. ``None`` and blank entries
            (e.g. an empty textbox or a trailing comma) are skipped.
        img: PIL image forwarded to ``process_image``.
        threshold: Cut-off applied to each normalized response map; pixels
            strictly above it are selected.

    Returns:
        A list of 2-D boolean arrays, one per non-blank prompt. The list is
        empty when no usable prompt was given.
    """
    names = [p.strip() for p in (prompts or "").split(",")]
    return [process_image(img, name) > threshold for name in names if name]


def _union_masks(masks, shape):
    """Return the element-wise OR of ``masks``, or all-False if there are none."""
    if not masks:
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the regions matching the positive prompts but not the negative ones.

    Args:
        pos_prompts: Comma-separated prompts for regions to keep.
        neg_prompts: Comma-separated prompts for regions to remove; may be blank.
        img: PIL image to cut from, or ``None`` when no image is available yet.
        threshold: Cut-off in ``[0, 1]`` applied to every prompt's response map.

    Returns:
        ``(output_image, final_mask)``: an RGBA image that is transparent
        outside the selection, and the selection as a greyscale PIL image
        (255 = kept, 0 = removed). ``(None, None)`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to segment yet (e.g. a prompt was typed before an upload).
        return None, None

    # Use the caller's threshold rather than a hard-coded 0.5.
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # PIL reports (width, height); the mask arrays are (height, width).
    shape = img.size[::-1]
    pos_mask = _union_masks(positive_masks, shape)
    neg_mask = _union_masks(negative_masks, shape)

    selected = pos_mask & ~neg_mask
    final_mask = Image.fromarray(selected.astype(np.uint8) * 255)

    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask
# Interactive UI. Input values are handed to `fn` positionally, so the order
# of `inputs` follows extract_image(pos_prompts, neg_prompts, img, threshold).
iface = gr.Interface(
    fn=extract_image,
    inputs=[
        gr.Textbox(label="Positive Prompts (comma separated)"),
        gr.Textbox(label="Negative Prompts (comma separated)"),
        gr.Image(type="pil", label="Input Image"),
        # `value` sets the slider's initial position.
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Mask"),
    ],
)
# Second interface with the same function and components, run in live mode.
# Input values are handed to `fn` positionally, so the order of `inputs`
# follows extract_image(pos_prompts, neg_prompts, img, threshold).
api_interface = gr.Interface(
    fn=extract_image,
    inputs=[
        gr.Textbox(label="Positive Prompts (comma separated)"),
        gr.Textbox(label="Negative Prompts (comma separated)"),
        gr.Image(type="pil", label="Input Image"),
        # `value` sets the slider's initial position.
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Mask"),
    ],
    live=True,  # re-run the function automatically whenever an input changes
)
# Run both Gradio apps.
# launch() blocks the main thread by default, so the first server is started
# with prevent_thread_lock=True; otherwise the second launch() would not be
# reached while the first server is still running.
iface.launch(prevent_thread_lock=True)
api_interface.launch(share=True)  # share=True requests a public share link