"""Gradio demo: zero-shot image segmentation with CLIPSeg.

The user uploads an image and types a comma-separated list of class names;
the app returns a figure with the input image followed by one sigmoid heat
map per class.
"""
from PIL import Image
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import io
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Loaded once at import time and shared by every request.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def visualize_segmentation(image, prompts, preds):
    """Render the input image and one heat map per prompt into a PNG image.

    Args:
        image: PIL image drawn in the first panel.
        prompts: list of prompt strings; one extra panel is drawn per prompt.
        preds: logits indexed as ``preds[i][0]`` for prompt ``i``, each a 2-D
            map (``segment`` passes a tensor of shape (len(prompts), 1, H, W)).

    Returns:
        A PIL image of the rendered matplotlib figure.
    """
    n_panels = len(prompts) + 1
    # squeeze=False keeps the axes as a 2-D array even for a single panel,
    # so indexing below works for any number of prompts.
    fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 4), squeeze=False)
    axes = axes[0]
    for a in axes:
        a.axis('off')
    axes[0].imshow(image)
    for i, prompt in enumerate(prompts):
        axes[i + 1].imshow(torch.sigmoid(preds[i][0]))
        # With imshow's default origin='upper', a negative y places the label
        # just above the heat map.
        axes[i + 1].text(0, -15, prompt)
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # free the figure; pyplot keeps a global reference otherwise
    buf.seek(0)  # rewind so the PNG is read from the start
    return Image.open(buf)


def segment(img, clases):
    """Segment ``img`` once for each comma-separated class name in ``clases``.

    Args:
        img: input image as an array, as passed by the Gradio "image" input
            (``None`` when the user submitted without an image).
        clases: comma-separated class names, e.g. "cutlery, pancakes".

    Returns:
        PIL image with the input image and one heat map per class.

    Raises:
        gr.Error: if no image was provided or no non-empty class was entered.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # Strip whitespace around each name and drop empty entries (e.g. from a
    # trailing comma) so they are neither sent to the model nor used as labels.
    prompts = [p.strip() for p in (clases or "").split(',') if p.strip()]
    if not prompts:
        raise gr.Error('Please enter at least one class, separated by ","')
    # convert('RGB') normalizes grayscale/RGBA arrays instead of forcing the
    # raw buffer to be interpreted as 3-channel RGB.
    image = Image.fromarray(np.asarray(img)).convert('RGB')
    inputs = processor(text=prompts, images=[image] * len(prompts),
                       padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Normalize to (n_prompts, 1, H, W) explicitly, regardless of whether the
    # model output kept the batch dimension when there is only one prompt.
    preds = logits.reshape(len(prompts), 1, *logits.shape[-2:])
    return visualize_segmentation(image, prompts, preds)


demo = gr.Interface(
    fn=segment,
    inputs=["image", gr.Textbox(label='Enter classes separated by ","')],
    outputs="image",
    # One inner list per example, with one value per input component
    # (image, classes text).
    examples=[['desayuno.jpg', 'cutlery, pancakes, blueberries, orange juice']],
)

if __name__ == "__main__":
    demo.launch()