# CLIPSeg2 / app.py
# Author: sigyllly · commit 4feabc5 ("Update app.py") · 2.78 kB
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
import threading
# The preprocessing pipeline and the segmentation network are both loaded
# once, at import time, from the same pretrained checkpoint id.
_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CHECKPOINT)
def process_image(image, prompt):
    """Compute a per-pixel relevance map of *prompt* over *image* with CLIPSeg.

    Args:
        image: PIL image to segment; the returned map is resized to its size.
        prompt: Text describing the region to highlight.

    Returns:
        2-D float array of shape ``(image.height, image.width)``, min-max
        rescaled so the most relevant pixel is 1.0 and the least is 0.0.
        If every pixel has the same score, an all-zero map is returned
        instead of dividing by zero (which would yield NaNs).
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid maps the raw logits to [0, 1]; squeeze() drops any singleton
    # batch dimension so a 2-D map reaches Image.fromarray.
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
    # Quantize to 8-bit grayscale and let Pillow upscale to the input size.
    low_res = Image.fromarray(np.uint8(probs * 255), "L")
    mask = np.asarray(low_res.resize(image.size), dtype=np.float64)

    mask_min = mask.min()
    span = mask.max() - mask_min
    if span == 0:
        # Flat map: nothing stands out, and normalizing would divide by zero.
        return np.zeros_like(mask)
    # Rescale relative to this prompt's own score range, so downstream
    # thresholds compare against a per-prompt range, not raw probabilities.
    return (mask - mask_min) / span
def get_masks(prompts, img, threshold):
prompts = prompts.split(",")
masks = []
for prompt in prompts:
mask = process_image(img, prompt)
mask = mask > threshold
masks.append(mask)
return masks
def extract_image(pos_prompts, neg_prompts, img, threshold):
positive_masks = get_masks(pos_prompts, img, 0.5)
negative_masks = get_masks(neg_prompts, img, 0.5)
pos_mask = np.any(np.stack(positive_masks), axis=0)
neg_mask = np.any(np.stack(negative_masks), axis=0)
final_mask = pos_mask & ~neg_mask
final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
output_image.paste(img, mask=final_mask)
return output_image, final_mask
def _extract_from_ui(image, pos_prompts, neg_prompts, threshold):
    """Adapt the UI's input order to ``extract_image``'s signature.

    Gradio passes component values positionally in the order of ``inputs``
    (image, positive prompts, negative prompts, threshold), whereas
    ``extract_image`` expects ``(pos_prompts, neg_prompts, img, threshold)``.
    """
    return extract_image(pos_prompts, neg_prompts, image, threshold)


iface = gr.Interface(
    fn=_extract_from_ui,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Positive Prompts (comma separated)"),
        gr.Textbox(label="Negative Prompts (comma separated)"),
        # `value` (not `default`) sets a Slider's initial value.
        gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Mask"),
    ],
)

# Every Gradio Interface also exposes its function as an API endpoint
# (callable e.g. via gradio_client), so a second Interface is unnecessary;
# `live=True` only re-runs the function whenever an input changes. The name
# is kept as an alias for backward compatibility.
api_interface = iface

if __name__ == "__main__":
    # When run as a script, launch() blocks until the server stops, so only
    # one server is started here. Pass share=True when running locally to
    # get a temporary public URL.
    iface.launch()