ShehryarAli's picture
Update app.py
a7f4cd7 verified
raw
history blame
2.39 kB
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import cm
import torch
from PIL import Image
import requests
from io import BytesIO
import gradio as gr
# Hugging Face Hub checkpoint used for both preprocessing and inference.
# Loading the processor and the model from the same checkpoint keeps the
# preprocessing consistent with what the model was configured for.
MODEL_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"

processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_CHECKPOINT)
def replace(text):
    """Run panoptic segmentation on an image and return the rendered result.

    This is the Gradio callback wired to the ``gr.Image(type="pil")`` input.

    Args:
        text: The uploaded image as a ``PIL.Image.Image``, or ``None`` when the
            user submits without an image. (The parameter name is historical and
            kept for backward compatibility; it is an image, not text.)

    Returns:
        A ``PIL.Image.Image`` produced by ``draw_panoptic_segmentation``, or
        ``None`` when no image was provided.
    """
    if text is None:
        # Nothing was uploaded; return an empty output instead of crashing
        # inside the image processor.
        return None

    # Normalize the color mode (e.g. RGBA / grayscale / palette uploads) to the
    # 3-channel RGB input the image processor expects.
    image = text.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable autograd so no gradient graph is built, which
    # saves memory and time on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes expects (height, width); PIL's `size` is (width, height).
    prediction = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    return draw_panoptic_segmentation(**prediction)
def draw_panoptic_segmentation(segmentation, segments_info):
    """Render a panoptic segmentation map to a PIL image.

    Args:
        segmentation: Per-pixel segment-id map, as returned under the
            ``"segmentation"`` key of ``post_process_panoptic_segmentation``
            (presumably a (height, width) torch tensor; anything that
            ``Axes.imshow`` accepts will render).
        segments_info: Per-segment metadata from the same post-processing call.
            Not used for rendering; it is accepted so the prediction dict can
            be passed straight through with ``**prediction``.

    Returns:
        A ``PIL.Image.Image`` holding the PNG-encoded figure.
    """
    fig, ax = plt.subplots()
    try:
        # Segment ids are colored through matplotlib's default colormap.
        ax.imshow(segmentation)
        buf = BytesIO()
        # Save this specific figure rather than pyplot's global "current"
        # figure, which could be a different figure if requests are handled
        # concurrently.
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even when rendering fails, so repeated
        # requests do not accumulate open figures.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Gradio UI: one image in (delivered to `replace` as a PIL image), one image out.
image_input = gr.Image(type="pil")
interface = gr.Interface(
    title="Image Segmentation with Mask Overlay",
    description="Upload an image to see the segmentation mask applied.",
    fn=replace,
    inputs=image_input,
    outputs="image",
)

# Start the web server. Per Gradio's docs, debug=True makes launch() block the
# main thread and surfaces errors raised by `replace` in the console output.
interface.launch(debug=True)