import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Load the CLIPSeg processor and model (rd64-refined checkpoint) once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
def visualize_segmentation(image, prompts, preds):
    """Plot the input image followed by one sigmoid heatmap per prompt."""
    fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3 * (len(prompts) + 1), 4))
    for a in ax.flatten():
        a.axis('off')
    ax[0].imshow(image)
    for i, prompt in enumerate(prompts):
        ax[i + 1].imshow(torch.sigmoid(preds[i][0]))  # logits -> [0, 1] heatmap
        ax[i + 1].text(0, -15, prompt)
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)  # rewind so PIL reads the PNG from the start of the buffer
    return Image.open(buf)
def segment(img, classes):
    """Segment the image once per comma-separated class prompt."""
    image = Image.fromarray(img, 'RGB')
    prompts = [p.strip() for p in classes.split(',')]
    # CLIPSeg scores one (image, prompt) pair at a time, so repeat the image.
    inputs = processor(text=prompts, images=[image] * len(prompts),
                       padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    if logits.dim() == 2:
        # Some transformers builds squeeze the batch dim for a single prompt.
        logits = logits.unsqueeze(0)
    preds = logits.unsqueeze(1)  # (num_prompts, 1, H, W)
    return visualize_segmentation(image, prompts, preds)
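# Optional post-processing sketch (an addition, not part of the original app):
# upsample one prompt's heatmap back to the source resolution and threshold it
# into a binary mask. The 0.5 cutoff and bilinear resize are illustrative
# assumptions, not values prescribed by CLIPSeg.
def binary_mask(pred, size, threshold=0.5):
    """Turn one (H, W) logits map into a PIL mask at `size` (width, height)."""
    prob = torch.sigmoid(pred)  # probabilities in [0, 1]
    prob = torch.nn.functional.interpolate(
        prob[None, None],         # (1, 1, H, W), since bilinear needs 4D input
        size=(size[1], size[0]),  # PIL size is (W, H); torch expects (H, W)
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    return Image.fromarray(((prob > threshold).numpy() * 255).astype(np.uint8))
# Usage inside segment(), for example: binary_mask(preds[0][0], image.size)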
# Minimal Gradio UI: an uploaded image plus a comma-separated list of classes.
demo = gr.Interface(fn=segment, inputs=["image", "text"], outputs="image")
demo.launch()