import argparse
import functools
import os
from collections import defaultdict

from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches
import numpy as np
import matplotlib
import gradio as gr

# Hugging Face checkpoint used for panoptic segmentation.
MODEL_NAME = "facebook/mask2former-swin-base-coco-panoptic"


def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    """Load an image, trim its borders, center-crop it to a square and resize it.

    Args:
        image_path: Path to an image file (read as RGB), or an already-loaded
            numpy array of shape ``(H, W)`` or ``(H, W, C)``.
        left, right, top, bottom: Number of pixels to trim from each border.
            They are clamped so that at least one row and one column remain.
        size: Side length of the returned square image.

    Returns:
        A numpy array of shape ``(size, size, ...)``.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # convert("RGB") also handles grayscale / palette files, which would
        # otherwise be 2-D and break channel slicing.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
    else:
        image = image_path

    h, w = image.shape[:2]
    # Clamp crop amounts so at least one pixel survives on each axis.
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)  # clamp against the height, not the left crop
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Center-crop the longer side so the image becomes square.
    h, w = image.shape[:2]
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((size, size)))


def draw_panoptic_segmentation(segmentation, segments_info, save_folder=None, noseg=False, model=None):
    """Render a panoptic segmentation map and split it into per-segment masks.

    Args:
        segmentation: 2-D tensor of segment ids. Pixels with id ``0`` are
            reported as a ``"rest"`` segment; a map filled entirely with
            ``-1`` is treated as "nothing detected".
        segments_info: Iterable of dicts with at least ``'id'`` (value in
            ``segmentation``) and ``'label_id'`` (key into
            ``model.config.id2label``).
        save_folder: Directory where ``seg_init.png`` is written. If ``None``
            the figure is not saved.
        noseg: If True, ignore the segments and return one mask covering the
            whole image, labelled ``"all-0"``. Forced to True when nothing
            was detected.
        model: Model whose ``config.id2label`` names the segment classes.
            Only needed when per-segment masks are produced.

    Returns:
        ``(mask_list, label_list)``: boolean CPU tensors shaped like
        ``segmentation`` and the matching ``"<label>-<index>"`` strings.
    """
    seg_min = int(torch.min(segmentation))
    seg_max = int(torch.max(segmentation))
    if seg_min == seg_max == -1:
        print("nothing is detected!")
        noseg = True
        num_colors = 1
    else:
        num_colors = seg_max - seg_min + 1
    viridis = matplotlib.colormaps['viridis'].resampled(num_colors)

    fig, ax = plt.subplots()
    ax.imshow(segmentation.cpu().numpy())

    handles = []
    label_list = []
    mask_list = []

    def _add(mask, name, index):
        # Record one mask together with its label and legend entry.
        label = f"{name}-{index}"
        mask_list.append(mask)
        label_list.append(label)
        handles.append(mpatches.Patch(color=viridis(index), label=label))

    if noseg:
        _add(torch.from_numpy(np.full(segmentation.shape, True)), "all", 0)
    else:
        has_rest = seg_min == 0
        if has_rest:
            _add((segmentation == 0).cpu().detach(), "rest", 0)
        for segment in segments_info:
            segment_id = segment['id']
            mask = (segmentation == segment_id).cpu().detach()
            # Without a "rest" region the ids are assumed to start at 1; shift
            # them so colour indices and labels start at 0.
            if not has_rest:
                segment_id -= 1
            _add(mask, model.config.id2label[segment['label_id']], segment_id)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(handles=handles)
    if save_folder is not None:
        fig.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500)
    # Release the figure; pyplot otherwise keeps every figure alive across calls.
    plt.close(fig)
    print("; ".join(label_list))
    return mask_list, label_list


@functools.lru_cache(maxsize=1)
def _load_mask2former():
    """Load the Mask2Former processor and model once and reuse them."""
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_NAME)
    return processor, model


def run_segmentation(image, name="example_tmp", size=512, noseg=False):
    """Run Mask2Former panoptic segmentation on an image.

    Writes the resized input to ``./<name>/img_<size>.png`` and the rendered
    segmentation to ``./<name>/seg_init.png``.

    Args:
        image: Image as a numpy array (as accepted by ``Image.fromarray``) or
            a ``PIL.Image.Image``. It is converted to RGB.
        name: Output folder, relative to the current working directory.
        size: Side length the image is resized to before segmentation.
        noseg: Forwarded to ``draw_panoptic_segmentation``. If True, a single
            whole-image mask is returned.

    Returns:
        ``(mask_list, label_list)`` as produced by ``draw_panoptic_segmentation``.
    """
    processor, model = _load_mask2former()
    save_folder = os.path.join(".", name)
    os.makedirs(save_folder, exist_ok=True)

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The COCO checkpoint expects 3-channel input; RGBA/grayscale would fail.
    image = image.convert("RGB").resize((size, size))
    image.save(os.path.join(save_folder, f"img_{size}.png"))

    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL size is (width, height); target_sizes expects (height, width).
    panoptic_segmentation = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    mask_list, label_list = draw_panoptic_segmentation(
        **panoptic_segmentation, save_folder=save_folder, noseg=noseg, model=model
    )
    print("Finish segment")
    return mask_list, label_list