from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import torch
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches
import os
import functools
import numpy as np
import argparse
import matplotlib
import gradio as gr
import cv2

# Hugging Face checkpoint used for panoptic segmentation.
MASK2FORMER_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"


def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    """Load an image, trim its margins, center-crop it to a square and resize it.

    Args:
        image_path: Path to an image file (``str`` or ``os.PathLike``), or an
            already-loaded image array of shape ``(H, W)`` or ``(H, W, C)``.
        left, right, top, bottom: Number of pixels to cut from each edge. Each
            value is clamped so that at least one row/column survives.
        size: Side length of the returned square image.

    Returns:
        A NumPy array with spatial shape ``(size, size)`` (plus the channel
        axis, if any), produced by PIL's ``resize``.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # convert("RGB") drops an alpha channel exactly like a [:, :, :3]
        # slice would, but also handles grayscale/palette/CMYK files, whose
        # raw arrays are 2-D or have non-RGB channels.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
    else:
        image = image_path

    h, w = image.shape[:2]
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Vertical clamps are against the image height (not the horizontal margin).
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Center-crop the longer side so the image becomes square.
    h, w = image.shape[:2]
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((size, size)))


def _resize_mask(mask, mask_size=512):
    """Resize a 2-D boolean mask to ``mask_size`` x ``mask_size`` and return it as bool.

    OpenCV 4.x's Python bindings do not accept NumPy ``bool`` arrays, so the
    mask is round-tripped through ``uint8``; nearest-neighbour interpolation
    keeps the result strictly binary.
    """
    resized = cv2.resize(
        mask.astype(np.uint8), (mask_size, mask_size), interpolation=cv2.INTER_NEAREST
    )
    return resized.astype(bool)


def draw_panoptic_segmentation(segmentation, segments_info, save_folder=None, noseg=False,
                               model=None, mask_size=512):
    """Build one boolean mask + label per segment and save a legend plot.

    Args:
        segmentation: Integer label map (torch tensor), presumably the
            ``"segmentation"`` entry of ``post_process_panoptic_segmentation``.
            A map that is ``-1`` everywhere is treated as "nothing detected";
            pixels labelled ``0`` are reported as a ``"rest"`` segment.
        segments_info: Iterable of dicts with at least ``"id"`` and
            ``"label_id"`` keys.
        save_folder: Directory where ``seg_init.png`` is written; when ``None``
            the plot is not saved.
        noseg: If true (or if nothing is detected) a single all-True mask
            labelled ``"all-0"`` is returned instead of per-segment masks.
        model: Model whose ``config.id2label`` maps label ids to class names;
            when ``None`` the numeric label id is used as the name.
        mask_size: Side length every returned mask is resized to.

    Returns:
        ``(mask_np_list, label_list)``: boolean ``(mask_size, mask_size)``
        arrays and their ``"<class>-<index>"`` labels, in matching order.
    """
    seg_min = int(torch.min(segmentation))
    seg_max = int(torch.max(segmentation))
    if seg_min == seg_max == -1:
        print("nothing is detected!")
        noseg = True
        n_colors = 1
    else:
        n_colors = seg_max - seg_min + 1
    viridis = matplotlib.colormaps['viridis'].resampled(n_colors)

    seg_np = segmentation.detach().cpu().numpy()
    fig, ax = plt.subplots()
    ax.imshow(seg_np)

    handles = []
    label_list = []
    mask_np_list = []

    def add_entry(mask, label, color_index):
        # Record one mask together with its legend patch and label.
        mask_np_list.append(_resize_mask(mask, mask_size))
        handles.append(mpatches.Patch(color=viridis(color_index), label=label))
        label_list.append(label)

    if noseg:
        add_entry(np.full(seg_np.shape, True), "all-0", 0)
    else:
        has_rest = seg_min == 0
        if has_rest:
            add_entry(seg_np == 0, "rest-0", 0)
        for segment in segments_info:
            segment_id = segment['id']
            mask = seg_np == segment_id
            # Without a "rest" region the ids are shifted so indices start at 0.
            if not has_rest:
                segment_id -= 1
            label_id = segment['label_id']
            if model is not None:
                segment_label = model.config.id2label[label_id]
            else:
                segment_label = str(label_id)
            add_entry(mask, f"{segment_label}-{segment_id}", segment_id)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(handles=handles)
    if save_folder is not None:
        fig.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500)
    # Close the figure so repeated calls (e.g. from a Gradio app) don't leak figures.
    plt.close(fig)
    print("; ".join(label_list))
    return mask_np_list, label_list


@functools.lru_cache(maxsize=1)
def _load_mask2former(checkpoint=MASK2FORMER_CHECKPOINT):
    """Load (once) and return the Mask2Former ``(processor, model)`` pair."""
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)
    return processor, model


def run_segmentation(image, name="example_tmp", size=512, noseg=False):
    """Run Mask2Former panoptic segmentation on ``image``.

    Args:
        image: Input image as a NumPy array (or a PIL image).
        name: Output sub-folder (relative to the current directory) where
            ``seg_init.png`` is written.
        size: The image is resized to ``size`` x ``size`` before inference.
        noseg: Forwarded to ``draw_panoptic_segmentation``.

    Returns:
        ``(image, mask_list, label_list)`` where ``image`` is the resized PIL
        image fed to the model.
    """
    base_folder_path = "."
    # The processor/model are cached so each call does not reload the weights.
    processor, model = _load_mask2former()

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.resize((size, size))

    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (W, H); target_sizes expects (H, W).
    panoptic_segmentation = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    save_folder = os.path.join(base_folder_path, name)
    os.makedirs(save_folder, exist_ok=True)
    mask_list, label_list = draw_panoptic_segmentation(
        **panoptic_segmentation, save_folder=save_folder, noseg=noseg, model=model
    )
    print("Finish segment")
    return image, mask_list, label_list