from PIL import Image
import torch
import matplotlib.pyplot as plt
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def visualize_segmentation(image, prompts, preds):
    # One panel for the input image plus one heatmap per prompt.
    fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3 * (len(prompts) + 1), 4))
    for a in ax.flatten():
        a.axis("off")
    ax[0].imshow(image)
    for i, prompt in enumerate(prompts):
        # Pass the logits through a sigmoid so each heatmap reads as probabilities.
        ax[i + 1].imshow(torch.sigmoid(preds[i][0]))
        ax[i + 1].text(0, -15, prompt)
    return fig


def segment(img, classes):
    prompts = [p.strip() for p in classes.split(",")]
    # Pair the same image with every text prompt.
    inputs = processor(
        text=prompts, images=[img] * len(prompts), padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    if logits.ndim == 2:  # a single prompt can come back without a batch dimension
        logits = logits.unsqueeze(0)
    preds = logits.unsqueeze(1)
    return visualize_segmentation(img, prompts, preds)


demo = gr.Interface(fn=segment, inputs=["image", "text"], outputs="plot")
demo.launch()
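# A minimal sketch of calling segment() directly, without launching the Gradio UI.
# The file name and prompt string below are hypothetical placeholders:
#
#   img = Image.open("example.jpg").convert("RGB")
#   fig = segment(img, "cat, remote, blanket")
#   fig.savefig("segmentation.png")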