# CLIPSeg2 / app.py
# sigyllly's picture
# Update app.py
# 3991df5 verified
# raw
# history blame
# 2.71 kB
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import numpy as np
import threading
# The processor and the segmentation model are loaded from the same
# pretrained checkpoint, so the identifier is spelled out only once.
_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CHECKPOINT)
# Function to process image and generate mask
def process_image(image, prompt):
    """Run CLIPSeg on ``image`` for one text ``prompt`` and return a soft mask.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region to look for.

    Returns:
        A 2-D float NumPy array with one value per pixel of ``image``
        (rows = image height, columns = image width), min-max rescaled to
        the range [0, 1]. A completely flat prediction yields all zeros.
    """
    # The processor already resizes and normalizes its input, so the image is
    # passed through it exactly once. Re-encoding the normalized tensor as
    # uint8 (values outside 0..255 wrap around) would corrupt the picture the
    # model sees.
    rgb_image = image.convert("RGB")
    inputs = processor(
        text=prompt, images=rgb_image, padding="max_length", return_tensors="pt"
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # The logits for a single image/prompt pair may carry a leading batch
    # axis; squeeze it away so PIL always receives a 2-D array.
    probabilities = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

    # A 2-D uint8 array is interpreted by PIL as a grayscale ("L") image.
    mask_image = Image.fromarray(np.uint8(probabilities * 255))
    mask_image = mask_image.resize(image.size)
    mask = np.array(mask_image).astype(np.float64)

    low = mask.min()
    high = mask.max()
    if high == low:
        # Flat mask: nothing to rescale, and dividing would yield NaNs.
        return np.zeros_like(mask)
    return (mask - low) / (high - low)
# Function to extract image using positive and negative prompts
def _prompt_masks(prompts, img, threshold):
    """Return one boolean mask per non-blank, comma-separated prompt."""
    masks = []
    for prompt in (prompts or "").split(","):
        prompt = prompt.strip()
        if prompt:
            masks.append(process_image(img, prompt) > threshold)
    return masks


def _union_of_masks(masks, shape):
    """Return the element-wise OR of ``masks``; all False if there are none."""
    if not masks:
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Cut out the regions matching the positive prompts but not the negative.

    Args:
        pos_prompts: Comma-separated descriptions of what to keep.
        neg_prompts: Comma-separated descriptions of what to exclude; may be
            empty.
        img: Input PIL image.
        threshold: Cut-off applied to each soft mask returned by
            ``process_image``; pixels scoring above it count as matching.

    Returns:
        A tuple ``(output_image, final_mask)``: an RGBA image that is
        transparent outside the selection, and the selection as an "L" mode
        PIL image (255 = selected, 0 = not selected).

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("An input image is required.")

    # The thresholding helper lives in this block so that the function does
    # not rely on a ``get_masks`` that is not defined in the visible part of
    # this file, and so that the caller-supplied ``threshold`` is honoured
    # instead of a fixed 0.5.
    shape = (img.height, img.width)
    pos_mask = _union_of_masks(_prompt_masks(pos_prompts, img, threshold), shape)
    neg_mask = _union_of_masks(_prompt_masks(neg_prompts, img, threshold), shape)

    selected = pos_mask & ~neg_mask
    final_mask = Image.fromarray(selected.astype(np.uint8) * 255)

    # Start from a fully transparent canvas and copy only the selected pixels.
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask
# Define Gradio interface: two prompt boxes, the input image and a threshold
# slider are passed, in this order, to extract_image; its two return values
# fill the "Result" and "Mask" image outputs.
iface = gr.Interface(
    fn=extract_image,
    inputs=[
        gr.Textbox(
            label="Please describe what you want to identify (comma separated)",
            key="pos_prompts",
        ),
        gr.Textbox(
            label="Please describe what you want to ignore (comma separated)",
            key="neg_prompts",
        ),
        gr.Image(type="pil", label="Input Image", key="img"),
        # The initial slider position is set with ``value``; ``default`` is
        # not an accepted keyword in current Gradio releases.
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.4,
            label="Threshold",
            key="threshold",
        ),
    ],
    outputs=[
        gr.Image(label="Result", key="output_image"),
        gr.Image(label="Mask", key="output_mask"),
    ],
)

# Launch Gradio API
iface.launch()