"""Gradio demo: panoptic segmentation of an uploaded image with Mask2Former."""

from collections import defaultdict
from io import BytesIO

import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import requests
import torch
from matplotlib import cm
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

MODEL_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"

processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_CHECKPOINT)


def replace(text):
    """Run panoptic segmentation on an image and return the rendered mask.

    Despite its name, ``text`` is the PIL image handed over by the Gradio
    ``Image`` input component; the parameter name is kept so existing callers
    keep working.

    Args:
        text: PIL image to segment.

    Returns:
        A PIL image containing the plotted segmentation map.

    Raises:
        gr.Error: If no image was provided.
    """
    if text is None:
        # Submitting the form without an upload passes None; report it to the
        # user instead of failing deep inside the image processor.
        raise gr.Error("Please upload an image before submitting.")

    inputs = processor(text, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    prediction = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[text.size[::-1]]
    )[0]
    return draw_panoptic_segmentation(**prediction)


def draw_panoptic_segmentation(segmentation, segments_info, *, show_legend=False):
    """Render a panoptic segmentation map as a PIL image.

    Args:
        segmentation: 2-D map of segment ids, as returned under the
            ``"segmentation"`` key by the processor's post-processing step.
        segments_info: Iterable of per-segment dicts; the ``"id"`` and
            ``"label_id"`` keys are read only when ``show_legend`` is true.
        show_legend: When true, add a legend with one ``<label>-<n>`` entry per
            segment, coloured like the corresponding region of the plot.
            Defaults to ``False``, which renders the bare mask.

    Returns:
        A PIL image decoded from the PNG rendering of the figure.
    """
    # Build the figure without pyplot: pyplot keeps global "current figure"
    # state, which is unsafe when Gradio calls this from worker threads.
    fig = Figure()
    ax = fig.subplots()
    mask_image = ax.imshow(segmentation)

    if show_legend:
        instances_counter = defaultdict(int)
        handles = []
        for segment in segments_info:
            label_id = segment["label_id"]
            label_name = model.config.id2label[label_id]
            label = f"{label_name}-{instances_counter[label_id]}"
            instances_counter[label_id] += 1
            # Reuse the plotted image's own norm and colormap so each legend
            # swatch matches the colour actually drawn for that segment id.
            color = mask_image.cmap(mask_image.norm(segment["id"]))
            handles.append(mpatches.Patch(color=color, label=label))
        ax.legend(handles=handles)

    # Save this specific figure to an in-memory PNG and decode it with PIL.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    pil_image = Image.open(buf)
    # Image.open is lazy; decode now so the result no longer depends on buf.
    pil_image.load()
    return pil_image


# Set up the Gradio interface.
interface = gr.Interface(
    fn=replace,  # The function to execute
    inputs=gr.Image(type="pil"),  # Input delivered as a PIL image
    outputs="image",  # Output rendered as an image
    title="Image Segmentation with Mask Overlay",  # Title for the Gradio app
    description="Upload an image to see the segmentation mask applied.",  # Description for the app
)

# Launch the Gradio app.
interface.launch(debug=True)