# Page residue from the Hugging Face file viewer (not part of the program):
# ancebuc's picture
# Update app.py
# dc599f7 verified
# raw
# history blame
# 1.3 kB
from PIL import Image
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import io
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Pretrained CLIPSeg checkpoint, fetched via `from_pretrained`. The processor
# and the model must come from the same checkpoint so their preprocessing and
# weights agree.
_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CHECKPOINT)
def visualize_segmentation(image, prompts, preds):
    """Render the input image next to one sigmoid heatmap per prompt.

    Args:
        image: Image shown in the first panel (anything ``Axes.imshow``
            accepts; ``segment`` passes a PIL image).
        prompts: Sequence of prompt strings. Each one is written above its
            heatmap panel.
        preds: Per-prompt logits, indexed as ``preds[i][0]`` to obtain a 2-D
            map (e.g. a ``(num_prompts, 1, H, W)`` tensor). ``torch.sigmoid``
            is applied here, so the panels show values in [0, 1].

    Returns:
        PIL.Image.Image decoded from a PNG rendering of the figure.
    """
    n_panels = len(prompts) + 1
    # squeeze=False always returns a 2-D array of Axes, even for one panel,
    # so the indexing below never receives a bare Axes object.
    fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 4), squeeze=False)
    axes = axes[0]
    for panel in axes:
        panel.axis("off")
    axes[0].imshow(image)
    for i, prompt in enumerate(prompts):
        panel = axes[i + 1]
        panel.imshow(torch.sigmoid(preds[i][0]))
        # With imshow's default origin="upper", negative y lies above the
        # first image row, so the caption sits just over the heatmap.
        panel.text(0, -15, prompt)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    # Free the figure: this runs once per request in a long-lived server.
    plt.close(fig)
    buf.seek(0)  # rewind so the PNG is read from its first byte
    return Image.open(buf)
def segment(img, clases):
    """Segment ``img`` once per comma-separated text prompt in ``clases``.

    Args:
        img: Image from the Gradio ``"image"`` input, a NumPy array
            (presumably uint8 HxWxC; TODO confirm for the Gradio version in
            use), or ``None`` when nothing was uploaded.
        clases: Comma-separated class names, e.g. ``"cat, dog"``. Whitespace
            around each name is stripped and empty entries are ignored.

    Returns:
        PIL image produced by ``visualize_segmentation``: the input followed
        by one heatmap per prompt.

    Raises:
        gr.Error: if no image was provided or no non-empty prompt was given.
            Gradio shows this message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # convert("RGB") also handles RGBA / grayscale arrays, which
    # fromarray(img, "RGB") would reinterpret incorrectly.
    image = Image.fromarray(img).convert("RGB")
    # Splitting "cat, dog" gives " dog"; strip names and drop empty ones.
    prompts = [p.strip() for p in (clases or "").split(",") if p.strip()]
    if not prompts:
        raise gr.Error("Please enter at least one class name (comma-separated).")
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # NOTE(review): some transformers versions squeeze away the batch
    # dimension when there is a single prompt, returning (H, W) instead of
    # (1, H, W). Restore it so every prompt has its own (H, W) map.
    logits = logits.reshape(len(prompts), *logits.shape[-2:])
    # (num_prompts, 1, H, W): the layout visualize_segmentation reads as preds[i][0].
    preds = logits.unsqueeze(1)
    return visualize_segmentation(image, prompts, preds)
# Gradio UI: an image upload plus a comma-separated text box, mapped through
# `segment` to an output image. `demo` stays module-level so tools that import
# this file can still find it.
demo = gr.Interface(fn=segment, inputs=["image", "text"], outputs="image")

# Start the (blocking) server only when run as a script, e.g. `python app.py`.
# Importing this module, for instance to test `segment`, no longer launches it.
if __name__ == "__main__":
    demo.launch()