import gradio as gr
import glob
import torch
import pickle
from PIL import Image, ImageDraw
import numpy as np
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from scipy.ndimage import center_of_mass

def combine_ims(im1, im2, val=128):
    # Blend im1 and im2 roughly 50/50 via a constant-value alpha mask.
    p = Image.new("L", im1.size, val)
    im = Image.composite(im1, im2, p)
    return im

def get_class_centers(segmentation_mask, class_dict):
    # Shift the predicted labels to 1-based so they line up with the class ids in class_dict.
    segmentation_mask = segmentation_mask.numpy() + 1
    class_centers = {}
    for class_index in class_dict:
        class_mask = (segmentation_mask == class_index).astype(int)
        class_centers[class_index] = center_of_mass(class_mask)
    # Keep only classes present in the image; absent classes yield NaN centers.
    class_centers = {k: list(map(int, v)) for k, v in class_centers.items() if not np.isnan(sum(v))}
    return class_centers

def visualize_mask(predicted_semantic_map, class_ids, class_colors):
    # Convert the per-pixel class labels into an RGB color mask.
    h, w = predicted_semantic_map.shape
    color_indexes = np.zeros((h, w), dtype=np.uint8)
    color_indexes[:] = predicted_semantic_map.numpy()
    color_indexes = color_indexes.flatten()
    colors = class_colors[class_ids[color_indexes]]
    output = colors.reshape(h, w, 3).astype(np.uint8)
    image_mask = Image.fromarray(output)
    return image_mask

def get_out_image(image, predicted_semantic_map):
    # Overlay the colored mask on the input image and write each class name at its center.
    class_centers = get_class_centers(predicted_semantic_map, class_dict)
    mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
    image_mask = combine_ims(image, mask, val=128)
    draw = ImageDraw.Draw(image_mask)
    extracted_tags = []
    for class_id, (y, x) in class_centers.items():
        class_name = str(class_names[class_id - 1])
        extracted_tags.append(class_name)  # collect only the class name
        draw.text((x, y), class_name, fill='black')
    # Join all detected class names into a single string separated by " | ".
    tags_string = " | ".join(extracted_tags)
    return image_mask, tags_string

def gradio_process(image):
    # Run Mask2Former semantic segmentation on the uploaded image.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Resize the prediction back to the input resolution; PIL's size is (w, h), hence the [::-1].
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    out_image, extracted_tags = get_out_image(image, predicted_semantic_map)
    return out_image, extracted_tags

# Load the ADE20K class names, ids and colors, and build a lookup from class id to name.
with open('ade20k_classes.pickle', 'rb') as f:
    class_names, class_ids, class_colors = pickle.load(f)
class_names, class_ids, class_colors = np.array(class_names), np.array(class_ids), np.array(class_colors)
class_dict = dict(zip(class_ids, class_names))

# Load the Mask2Former ADE20K semantic segmentation checkpoint on CPU.
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-ade-semantic").to(device)
model.eval()
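
# A possible variant (assumption, not part of the original Space): run on a GPU when one
# is available, e.g. device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
# and move both the model and the tensors returned by the processor to that device.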

# gr.inputs / gr.outputs were removed in Gradio 4.x; use the top-level components instead.
demo = gr.Interface(
    gradio_process,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="Semantic Segmentation",
    examples=glob.glob('./examples/*.jpg'),
    allow_flagging="never",
)

demo.launch()
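
# Hypothetical local test (assumption, not part of the original Space): run the pipeline
# on a single image without launching the Gradio UI. Uncomment and adjust the path to try it.
# img = Image.open("examples/sample.jpg").convert("RGB")
# annotated, tags = gradio_process(img)
# annotated.save("annotated.jpg")
# print(tags)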