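"""Gradio demo: semantic segmentation with Mask2Former on ADE20K classes.

Runs facebook/mask2former-swin-large-ade-semantic over an uploaded image,
overlays the predicted colour mask, labels each detected class at its centre
of mass, and returns the annotated image plus a " | "-joined tag string.
"""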
import glob
import pickle

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import center_of_mass
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

def combine_ims(im1, im2, val=128):
    """Blend im1 over im2 with a constant alpha mask (val/255; 128 is ~50/50)."""
    p = Image.new("L", im1.size, val)
    return Image.composite(im1, im2, p)

def get_class_centers(segmentation_mask, class_dict):
    """Compute the (y, x) centre of mass for every class present in the mask."""
    # Shift the 0-based label map so it lines up with the 1-based class ids.
    segmentation_mask = segmentation_mask.numpy() + 1
    class_centers = {}
    for class_index in class_dict:
        class_mask = (segmentation_mask == class_index).astype(int)
        # center_of_mass returns (nan, nan) when the class is absent.
        class_centers[class_index] = center_of_mass(class_mask)

    # Keep only classes that actually appear, with integer pixel coordinates.
    return {k: list(map(int, v)) for k, v in class_centers.items() if not np.isnan(sum(v))}

def visualize_mask(predicted_semantic_map, class_ids, class_colors):
    """Map each pixel's predicted label to its class colour as a PIL image."""
    h, w = predicted_semantic_map.shape
    color_indexes = predicted_semantic_map.numpy().astype(np.uint8).flatten()
    colors = class_colors[class_ids[color_indexes]]
    output = colors.reshape(h, w, 3).astype(np.uint8)
    return Image.fromarray(output)

def get_out_image(image, predicted_semantic_map):
    """Overlay the colour mask on the image and label each class at its centre."""
    class_centers = get_class_centers(predicted_semantic_map, class_dict)
    mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
    image_mask = combine_ims(image, mask, val=128)
    draw = ImageDraw.Draw(image_mask)

    extracted_tags = []
    for class_id, (y, x) in class_centers.items():
        class_name = str(class_names[class_id - 1])
        extracted_tags.append(class_name)
        draw.text((x, y), class_name, fill="black")

    # Join all detected class names into a single " | "-separated string.
    tags_string = " | ".join(extracted_tags)

    return image_mask, tags_string


def gradio_process(image):
    """Run Mask2Former on a PIL image; return the annotated image and tags."""
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the predicted map back to the input size; PIL gives (w, h), the
    # processor expects (h, w).
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    out_image, extracted_tags = get_out_image(image, predicted_semantic_map)
    return out_image, extracted_tags

# Load the ADE20K class metadata: parallel lists of names, ids, and RGB colours.
with open('ade20k_classes.pickle', 'rb') as f:
    class_names, class_ids, class_colors = pickle.load(f)
class_names, class_ids, class_colors = (
    np.array(class_names),
    np.array(class_ids),
    np.array(class_colors),
)
class_dict = dict(zip(class_ids, class_names))

# Load the processor and model once at startup; inference runs on CPU.
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
).to(device)
model.eval()

demo = gr.Interface(
    gradio_process,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="Semantic Segmentation",
    examples=glob.glob('./examples/*.jpg'),
    allow_flagging="never",
)

demo.launch()
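
# A sketch of driving the running app programmatically (hypothetical example
# path; assumes the default local URL and the separate gradio_client package):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   annotated_image, tags = client.predict(handle_file("examples/room.jpg"))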