|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
import threading |
|
|
|
# Both the processor and the model are loaded from the same pretrained checkpoint.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Preprocessor that turns (text, image) pairs into model inputs.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
# CLIPSeg segmentation network; process_image converts its logits into masks.
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
|
|
|
def process_image(image, prompt):
    """Run CLIPSeg on ``image`` for one text ``prompt`` and return a score map.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region to locate.

    Returns:
        A 2-D float ``numpy`` array indexed ``[row, col]`` with the same
        width/height as ``image``, min-max scaled to ``[0, 1]``. If the
        prediction is completely flat, an all-zero array is returned instead
        of dividing 0 by 0 (which would produce NaNs).
    """
    # The processor normalizes three channels; convert RGBA/L/P inputs first.
    inputs = processor(
        text=prompt,
        images=image.convert("RGB"),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.sigmoid(outputs.logits).cpu().numpy()
    # Drop any singleton batch/channel axes so Pillow always receives a 2-D
    # array, whichever exact logits shape the model returns for one prompt.
    probs = np.squeeze(probs)

    # A 2-D uint8 array is read as an 8-bit grayscale ("L") image; resize the
    # prediction to the input image's (width, height).
    heatmap = Image.fromarray(np.uint8(probs * 255)).resize(image.size)
    mask = np.asarray(heatmap)

    mask_min = mask.min()
    mask_max = mask.max()
    if mask_max == mask_min:
        # Nothing stands out from the background; avoid a 0/0 division.
        return np.zeros(mask.shape, dtype=np.float64)
    return (mask - mask_min) / (mask_max - mask_min)
|
|
|
def get_masks(prompts, img, threshold):
    """Return one boolean mask per comma-separated prompt in ``prompts``.

    Args:
        prompts: Comma-separated text prompts, e.g. ``"cat, red ball"``.
            Blank entries and ``None`` are ignored.
        img: PIL image to segment.
        threshold: Pixels whose normalized score is strictly greater than
            this value are marked ``True``.

    Returns:
        A list of boolean arrays (one per non-blank prompt), each shaped like
        the array returned by ``process_image``. The list is empty when no
        non-blank prompt was supplied.
    """
    # An empty textbox gives "".split(",") == [""]; running the model on an
    # empty prompt would yield a meaningless mask, so blank entries are skipped.
    cleaned = [p.strip() for p in (prompts or "").split(",")]
    return [process_image(img, prompt) > threshold for prompt in cleaned if prompt]
|
|
|
def _union_masks(masks, shape):
    """Return the element-wise OR of ``masks``, or an all-False array if there are none."""
    if not masks:
        # np.stack([]) raises ValueError; treat "no masks" as "nothing selected".
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the regions of ``img`` matched by positive but not negative prompts.

    Args:
        pos_prompts: Comma-separated prompts describing regions to keep.
        neg_prompts: Comma-separated prompts describing regions to remove
            from the kept area; may be empty.
        img: PIL input image.
        threshold: Score cut-off passed to ``get_masks`` for both prompt sets.

    Returns:
        ``(output_image, final_mask)``: an RGBA image the size of ``img`` that
        is transparent outside the selection, and the selection itself as a
        single-channel PIL image (255 = kept, 0 = dropped).

    Raises:
        ValueError: If ``img`` is ``None`` (no image was supplied).
    """
    if img is None:
        raise ValueError("An input image is required.")

    # Use the caller's threshold rather than a fixed cut-off.
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # PIL sizes are (width, height); the mask arrays are (height, width).
    shape = (img.size[1], img.size[0])
    selected = _union_masks(positive_masks, shape) & ~_union_masks(negative_masks, shape)

    # A 2-D uint8 array is read by Pillow as an 8-bit grayscale ("L") image.
    final_mask = Image.fromarray(selected.astype(np.uint8) * 255)

    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask
|
|
|
def _extract_from_ui(img, pos_prompts, neg_prompts, threshold):
    """Adapt the UI input order (image first) to ``extract_image``'s signature."""
    # Gradio passes input component values positionally, so the argument order
    # here must match the ``inputs`` list in _build_interface.
    return extract_image(pos_prompts, neg_prompts, img, threshold)


def _build_interface(**interface_kwargs):
    """Build the segmentation UI; extra keyword arguments go to ``gr.Interface``."""
    return gr.Interface(
        fn=_extract_from_ui,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Textbox(label="Positive Prompts (comma separated)"),
            gr.Textbox(label="Negative Prompts (comma separated)"),
            # Slider's initial position is set with `value`; `default` is a
            # legacy keyword that newer Gradio releases ignore or reject.
            gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
        ],
        outputs=[
            gr.Image(type="pil", label="Output Image"),
            gr.Image(type="pil", label="Output Mask"),
        ],
        **interface_kwargs,
    )


iface = _build_interface()

api_interface = _build_interface(live=True)

# launch() blocks the calling thread by default, which would keep the second
# app from starting until the first one shut down.
iface.launch(prevent_thread_lock=True)

api_interface.launch(share=True)
|
|