# Hugging Face Spaces page residue captured with this file (status: Sleeping).
from collections import defaultdict
from collections.abc import Mapping
from io import BytesIO
from typing import Optional

import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import requests
import torch
from matplotlib import cm
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
# Single checkpoint id shared by the processor and the model, so their
# preprocessing config and weights can never drift apart.
MODEL_ID = "facebook/mask2former-swin-base-coco-panoptic"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_ID)
def replace(text: Optional["Image.Image"]) -> Optional["Image.Image"]:
    """Run Mask2Former panoptic segmentation on an image and render the result.

    Args:
        text: The input PIL image. Despite its name it is not text; the name is
            kept so existing callers keep working. ``None`` (what Gradio sends
            when no image was provided) is accepted.

    Returns:
        The rendered segmentation map as a PIL image, or ``None`` when no
        input image was given.
    """
    if text is None:
        return None
    # Normalise to 3 channels: uploads may be RGBA, greyscale or palette images.
    image = text.convert("RGB")
    # Pure inference: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
    # target_sizes wants (height, width); PIL's .size is (width, height).
    prediction = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    return draw_panoptic_segmentation(**prediction)
def draw_panoptic_segmentation(
    segmentation: torch.Tensor,
    segments_info: list,
    *,
    show_legend: bool = False,
    id2label: Optional[Mapping[int, str]] = None,
) -> "Image.Image":
    """Render a panoptic segmentation map to a PIL image using matplotlib.

    Args:
        segmentation: Per-pixel map of segment ids, as returned under the
            ``"segmentation"`` key of ``post_process_panoptic_segmentation``.
        segments_info: Per-segment dicts from the same call; the legend reads
            their ``"id"`` and ``"label_id"`` keys.
        show_legend: When True, add a legend entry "<class>-<instance index>"
            for every segment. Off by default.
        id2label: Class-id -> class-name mapping used for legend labels.
            Defaults to the module-level ``model.config.id2label``.

    Returns:
        The rendered figure, PNG-encoded and reopened as a PIL image.
    """
    fig, ax = plt.subplots()
    try:
        # Keep the AxesImage so legend swatches use exactly the colormap and
        # normalisation that imshow applied to the segment ids.
        rendered = ax.imshow(segmentation)
        if show_legend:
            labels = model.config.id2label if id2label is None else id2label
            instances_counter = defaultdict(int)
            handles = []
            for segment in segments_info:
                segment_id = segment["id"]
                label_id = segment["label_id"]
                label = f"{labels[label_id]}-{instances_counter[label_id]}"
                instances_counter[label_id] += 1
                color = rendered.cmap(rendered.norm(segment_id))
                handles.append(mpatches.Patch(color=color, label=label))
            if handles:
                ax.legend(handles=handles)
        buf = BytesIO()
        # Save this figure explicitly rather than pyplot's global "current"
        # figure, which concurrent requests could change underneath us.
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)  # release the figure even if rendering fails
    buf.seek(0)
    return Image.open(buf)
# Gradio UI: one PIL image in, the rendered segmentation image out.
interface = gr.Interface(
    fn=replace,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="Image Segmentation with Mask Overlay",
    description="Upload an image to see the segmentation mask applied.",
)

if __name__ == "__main__":
    # Launch only when run as a script (e.g. `python app.py`), not on import.
    interface.launch(debug=True)