File size: 2,269 Bytes
356933b
 
 
 
 
 
 
 
 
2a95eb0
356933b
 
2a95eb0
 
356933b
 
 
 
 
2a95eb0
356933b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a95eb0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

from collections import defaultdict
from io import BytesIO

import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import requests
import torch
from matplotlib import cm
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

# One checkpoint id shared by the processor and the model so the two can
# never drift apart (the processor's label/resize config must match the model).
MODEL_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"
processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_CHECKPOINT)

def replace(text):
    """Run panoptic segmentation on an image and return a rendered label map.

    Gradio calls this with the uploaded image. Despite its name, ``text`` is a
    PIL image, because the interface declares ``gr.Image(type="pil")``.

    Args:
        text: The input ``PIL.Image.Image``, or ``None`` when the user submits
            without uploading anything.

    Returns:
        The PIL image produced by ``draw_panoptic_segmentation``.

    Raises:
        gr.Error: If no image was supplied; Gradio shows the message in the UI.
    """
    if text is None:
        raise gr.Error("Please upload an image first.")
    # Normalise palette/RGBA/grayscale uploads to the 3-channel input the
    # image processor expects.
    image = text.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes wants (height, width); PIL's .size is (width, height).
    prediction = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    return draw_panoptic_segmentation(**prediction)
def draw_panoptic_segmentation(segmentation, segments_info, id2label=None):
    """Render a panoptic segmentation map with a legend naming each segment.

    Args:
        segmentation: Tensor of per-pixel segment ids (expected to be the
            (height, width) ``"segmentation"`` entry returned by
            ``post_process_panoptic_segmentation``).
        segments_info: List of dicts, each with an ``"id"`` (the value used
            in ``segmentation``) and a ``"label_id"`` (class index).
        id2label: Optional mapping from class index to class name. Defaults
            to the loaded model's ``config.id2label``.

    Returns:
        A ``PIL.Image.Image`` decoded from the figure rendered as PNG.
    """
    if id2label is None:
        id2label = model.config.id2label

    # Move to CPU / NumPy so matplotlib can consume it regardless of device.
    seg = segmentation.cpu().numpy()

    # Build the figure without pyplot: Gradio may serve requests from worker
    # threads, and pyplot's global "current figure" state is not thread-safe.
    fig = Figure()
    ax = fig.subplots()
    # The image's own colormap + normalisation is reused for the legend
    # below, so each swatch is exactly the colour of its segment. "nearest"
    # stops resampling from blending neighbouring ids into colours that have
    # no legend entry. vmin == vmax (single id, or no segments at all) is
    # handled by Normalize and no longer crashes.
    image = ax.imshow(
        seg,
        cmap="viridis",
        interpolation="nearest",
        vmin=seg.min(),
        vmax=seg.max(),
    )

    instances_counter = defaultdict(int)
    handles = []
    for segment in segments_info:
        segment_id = segment["id"]
        segment_label_id = segment["label_id"]
        segment_label = id2label[segment_label_id]
        # Number repeated classes (e.g. "person-0", "person-1").
        label = f"{segment_label}-{instances_counter[segment_label_id]}"
        instances_counter[segment_label_id] += 1
        color = image.cmap(image.norm(segment_id))
        handles.append(mpatches.Patch(color=color, label=label))

    # An empty legend would just draw a blank box when nothing was detected.
    if handles:
        ax.legend(handles=handles)

    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    pil_image = Image.open(buf)
    pil_image.load()  # decode now instead of lazily reading the buffer later
    return pil_image
# Set up the Gradio interface.
interface = gr.Interface(
    fn=replace,  # The function to execute
    inputs=gr.Image(type="pil"),  # Input type as PIL image
    outputs="image",  # Output type as an image
    title="Image Segmentation with Mask Overlay",  # Title for the Gradio app
    description="Upload an image to see the segmentation mask applied."  # Description for the app
)

# Launch only when run as a script, so importing this module (e.g. from a
# test or another app) does not start a blocking web server as a side effect.
if __name__ == "__main__":
    interface.launch(debug=True)