# Author: ancebuc
# Last commit: "Update app.py" (98d8338, verified)
from PIL import Image
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import io
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Pretrained CLIPSeg checkpoint; the processor (tokenizer + image
# preprocessing) and the segmentation model are loaded from the same id so
# they stay in sync.
_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CHECKPOINT)
def visualize_segmentation(image, prompts, preds):
    """Render the input image next to one sigmoid heat-map per prompt.

    Args:
        image: The image that was segmented; drawn in the first panel.
        prompts: Text prompts, one per prediction; each is written above
            its panel.
        preds: Indexable collection of logit tensors where ``preds[i][0]``
            is a 2-D map for ``prompts[i]`` (e.g. shape (N, 1, H, W)).

    Returns:
        A PIL image of the whole figure, encoded as PNG.
    """
    n_panels = len(prompts) + 1
    # squeeze=False keeps `axes` an array even when there is only one panel,
    # so indexing and iteration below work for any number of prompts.
    fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 4), squeeze=False)
    axes = axes.ravel()
    try:
        for axis in axes:
            axis.axis('off')
        axes[0].imshow(image)
        for i, prompt in enumerate(prompts):
            panel = axes[i + 1]
            # Sigmoid maps raw logits to [0, 1] for display.
            panel.imshow(torch.sigmoid(preds[i][0]))
            # imshow puts row 0 at the top, so y=-15 sits just above the map.
            panel.text(0, -15, prompt)
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if rendering fails; pyplot
        # otherwise keeps it alive in its global figure registry.
        plt.close(fig)
    # Rewind so the PNG bytes are read from the start.
    buf.seek(0)
    return Image.open(buf)
def segment(img, clases):
    """Segment each comma-separated class name of ``clases`` inside ``img``.

    Args:
        img: Image array from the Gradio "image" input, or ``None`` when the
            user submitted without uploading an image.
        clases: Class names separated by commas, e.g.
            "cutlery, pancakes, blueberries".

    Returns:
        The PIL figure produced by ``visualize_segmentation``.

    Raises:
        gr.Error: If no image was given or ``clases`` contains no class
            names; Gradio shows the message to the user.
    """
    if img is None:
        raise gr.Error('Please upload an image.')
    # Drop the space that usually follows each comma and skip empty entries
    # (from "a,,b", a trailing comma, or an empty box) so the model never
    # receives a blank prompt.
    prompts = [p.strip() for p in (clases or '').split(',') if p.strip()]
    if not prompts:
        raise gr.Error('Please enter at least one class.')
    # Let Pillow infer the mode from the array, then normalize to RGB; this
    # also copes with grayscale or RGBA arrays instead of misreading them.
    image = Image.fromarray(img).convert('RGB')
    inputs = processor(text=prompts, images=[image] * len(prompts),
                       padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Force logits to (N, H, W): some transformers versions squeeze away the
    # batch dimension when only one prompt is given, which would leave the
    # unsqueezed tensor with the wrong rank for visualization.
    logits = logits.reshape(len(prompts), *logits.shape[-2:])
    preds = logits.unsqueeze(1)  # (N, 1, H, W)
    return visualize_segmentation(image, prompts, preds)
# Gradio UI: an image plus a comma-separated list of class names in, the
# segmentation figure (one heat-map per class) out.
demo = gr.Interface(
    fn=segment,
    inputs=["image", gr.Textbox(label='Enter classes separated by ","')],
    outputs="image",
    examples=[['desayuno.jpg', 'cutlery, pancakes, blueberries, orange juice']],
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    demo.launch()