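"""Gradio demo for zero-shot, text-prompted image segmentation with CLIPSeg.

The user uploads an image and a comma-separated list of class prompts; the app
scores the image against each prompt with the CIDAS/clipseg-rd64-refined model
and shows a heatmap of the predicted mask for each prompt next to the original image.
"""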
import io

from PIL import Image
import matplotlib.pyplot as plt
import torch
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and segmentation model (refined rd64 checkpoint).
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def visualize_segmentation(image, prompts, preds):
    # One panel for the original image plus one heatmap per prompt.
    fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3 * (len(prompts) + 1), 4))
    for a in ax.flatten():
        a.axis('off')
    ax[0].imshow(image)
    for i, prompt in enumerate(prompts):
        ax[i + 1].imshow(torch.sigmoid(preds[i][0]))
        ax[i + 1].text(0, -15, prompt)

    # Render the figure to an in-memory PNG and return it as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)  # rewind so PIL reads from the start of the buffer

    return Image.open(buf)

def segment(img, clases):
    image = Image.fromarray(img).convert('RGB')
    prompts = [p.strip() for p in clases.split(',')]

    # CLIPSeg scores each (image, prompt) pair, so the image is repeated once per prompt.
    inputs = processor(text=prompts, images=[image] * len(prompts),
                       padding="max_length", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    return visualize_segmentation(image, prompts, preds)


# Inputs: an image upload and a comma-separated text box of class names; output: the rendered figure.
demo = gr.Interface(fn=segment, inputs=["image", "text"], outputs="image")
demo.launch()
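
# To run locally (assuming this file is saved as app.py):
#   pip install gradio torch transformers matplotlib pillow
#   python app.py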