# ancebuc's picture
# Update app.py
# cf253ac verified
# raw
# history blame
# 1.13 kB
from PIL import Image
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# The text/image preprocessor and the segmentation network are both loaded
# from the same pretrained checkpoint, once, at import time. segment() reads
# these two module-level objects on every call.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
def visualize_segmentation(image, prompts, preds):
    """Plot the input image next to one predicted mask per prompt.

    Draws a single-row matplotlib figure. The first panel shows ``image``;
    each following panel shows ``sigmoid(preds[i][0])`` with the matching
    prompt written above it.

    Args:
        image: Anything ``Axes.imshow`` accepts (e.g. a PIL image or array).
        prompts: Sequence of prompt strings, one per mask panel.
        preds: Indexable of torch tensors; ``preds[i][0]`` is the 2-D map
            shown for ``prompts[i]``. Presumably the model logits with an
            extra channel axis -- TODO confirm against the caller.

    Returns:
        The matplotlib ``Figure`` that was drawn. Callers that ignored the
        previous ``None`` return value are unaffected.

    Raises:
        IndexError: If ``preds`` has fewer entries than ``prompts``.
    """
    n_panels = len(prompts) + 1
    # squeeze=False always yields a 2-D array of Axes, so the single-panel
    # case (no prompts) is indexable exactly like the multi-panel case.
    fig, axes_grid = plt.subplots(
        1, n_panels, figsize=(3 * n_panels, 4), squeeze=False
    )
    axes = axes_grid[0]

    for axis in axes:
        axis.axis('off')

    axes[0].imshow(image)

    for i, prompt in enumerate(prompts):
        # Detach and convert explicitly: imshow cannot convert a tensor that
        # still requires grad or lives on a non-CPU device.
        mask = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        axes[i + 1].imshow(mask)
        # Label drawn in data coordinates, 15 px above the top-left corner.
        axes[i + 1].text(0, -15, prompt)

    return fig
def segment(img, clases):
    """Run CLIPSeg on one image for a comma-separated list of text prompts.

    Args:
        img: Image as a NumPy array; presumably what Gradio's ``"image"``
            input component passes in -- TODO confirm dtype/shape.
        clases: Comma-separated prompts, e.g. ``"cat, dog"``. Surrounding
            whitespace is stripped and empty entries are ignored.

    Returns:
        A greeting string listing the prompts that were processed (the
        interface declares a text output).

    Raises:
        ValueError: If no image was supplied or no non-empty prompt remains.
    """
    if img is None:
        raise ValueError("An input image is required.")

    # Normalise the prompt list: "cat, dog," -> ["cat", "dog"].
    prompts = [p.strip() for p in (clases or "").split(',') if p.strip()]
    if not prompts:
        raise ValueError("At least one non-empty prompt is required.")

    # Let Pillow infer the mode from the array, then normalise to RGB, rather
    # than forcing the raw bytes to be interpreted as RGB.
    image = Image.fromarray(img).convert('RGB')

    # The processor pairs each prompt with its own copy of the image.
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # NOTE(review): preds is computed but not surfaced yet -- the interface
    # only has a text output. It matches the ``preds`` argument expected by
    # visualize_segmentation.
    preds = outputs.logits.unsqueeze(1)

    # The original concatenated the prompt *list* to a str, which always
    # raised TypeError; join the prompts into text instead.
    return "Hello " + ", ".join(prompts) + "!!"
# Wrap segment() in a Gradio UI: an image input and a text input are passed
# to the function, and its return value is displayed as text.
demo = gr.Interface(
    fn=segment,
    inputs=["image", "text"],
    outputs="text",
)

# Start serving the interface.
demo.launch()