# app.py -- web-page header residue, preserved as a comment so the file parses:
# ancebuc's picture | Update app.py | 2561975 verified | raw | history blame | 1.27 kB
from PIL import Image
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import io
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Load the CLIPSeg processor and model once at import time; both come from
# the same pretrained checkpoint, so the identifier is named a single time.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
def visualize_segmentation(image, prompts, preds):
    """Render the input image next to one predicted mask per prompt.

    Args:
        image: Picture drawn in the first panel (anything ``imshow`` accepts).
        prompts: Sequence of text prompts; each one is written as a label
            above its own mask panel.
        preds: Indexable collection in which ``preds[i][0]`` holds the raw
            logits (a torch tensor) for ``prompts[i]``. A sigmoid is applied
            before plotting.

    Returns:
        A PIL image opened from the PNG rendering of the figure.
    """
    n_panels = len(prompts) + 1
    # squeeze=False keeps the axes in an array even when there is only one
    # panel (no prompts), where subplots would otherwise return a bare Axes.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(3 * n_panels, 4), squeeze=False
    )
    try:
        axes = axes.ravel()
        for axis in axes:
            axis.axis('off')
        axes[0].imshow(image)
        for i, prompt in enumerate(prompts):
            # Detach before handing the data to matplotlib: a tensor that
            # still requires grad cannot be converted to a NumPy array.
            mask = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
            axes[i + 1].imshow(mask)
            axes[i + 1].text(0, -15, prompt)
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if rendering failed part-way.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def segment(img, clases):
    """Segment ``img`` once per class name and return a visual summary.

    Args:
        img: Image array passed to ``Image.fromarray(..., 'RGB')``
            (presumably the H x W x 3 uint8 array Gradio supplies for an
            "image" input -- confirm against the interface definition).
        clases: Comma-separated class names / prompts. Surrounding
            whitespace is trimmed and empty entries are ignored.

    Returns:
        The image produced by ``visualize_segmentation``: the input picture
        followed by one mask panel per prompt.

    Raises:
        ValueError: If no image was supplied or no non-empty prompt remains
            after parsing ``clases``.
    """
    if img is None:
        raise ValueError("An input image is required.")

    # Drop blanks so inputs such as "cat, ,dog," do not yield empty prompts.
    prompts = [part.strip() for part in (clases or "").split(',')]
    prompts = [prompt for prompt in prompts if prompt]
    if not prompts:
        raise ValueError("Provide at least one comma-separated class name.")

    image = Image.fromarray(img, 'RGB')
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping so the outputs carry no
    # gradient history (saves memory and keeps them NumPy-convertible).
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # NOTE(review): with a single prompt some transformers versions appear to
    # return 2-D logits (batch axis squeezed away); restore it so that
    # preds[i][0] is always a full 2-D map. Confirm against the pinned version.
    if logits.dim() == 2:
        logits = logits.unsqueeze(0)
    preds = logits.unsqueeze(1)
    return visualize_segmentation(image, prompts, preds)
# Build the web UI: an image plus a text box of comma-separated classes go
# in, and the rendered segmentation figure comes out. Then start serving.
demo = gr.Interface(
    fn=segment,
    inputs=["image", "text"],
    outputs="image",
)
demo.launch()