File size: 3,736 Bytes
c38a591 a988bac 8de291e a988bac 8de291e a988bac 8de291e a988bac c38a591 a988bac 8de291e 5280321 8de291e a988bac 8de291e a988bac 8de291e a988bac 8de291e a988bac 8de291e a988bac 8de291e ef7902e 8de291e ef7902e c38a591 ef7902e 8de291e 37451d1 8de291e 5280321 8de291e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import gradio as gr
#
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from matplotlib import gridspec
from PIL import Image
import numpy as np
import tensorflow as tf
import requests
# SegFormer checkpoint shared by the preprocessor and the model; keeping it in
# one constant guarantees both are always loaded from the same weights.
MODEL_CHECKPOINT = "nvidia/segformer-b0-finetuned-cityscapes-640-1280"
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)
def my_palette():
    """Return the RGB colour table used to paint segmentation masks.

    Entry ``i`` is the colour for class id ``i`` (callers index/enumerate the
    table by class id). A fresh list is returned on every call.

    NOTE(review): the Cityscapes checkpoint loaded by this app predicts 19
    classes (ids 0-18); the table previously had only 18 entries, so the last
    class was never coloured. Confirm the count against labels.txt /
    ``model.config.id2label``.

    Returns:
        list[list[int]]: ``[R, G, B]`` triples with components in 0..255.
    """
    return [
        [131, 162, 255],
        [180, 189, 255],
        [255, 227, 187],
        [255, 210, 143],
        [248, 117, 170],
        [255, 223, 223],
        [255, 246, 246],
        [174, 222, 252],
        [150, 194, 145],
        [255, 219, 170],
        [244, 238, 238],
        [50, 38, 83],
        [128, 98, 214],
        [146, 136, 248],
        [255, 210, 215],
        [255, 152, 152],
        [162, 103, 138],
        [63, 29, 56],
        [119, 11, 32],  # class id 18 (was missing)
    ]
# Class names, one per line: line i names class id i (the legend indexes this
# list with predicted class ids).
with open("labels.txt", "r", encoding="utf-8") as fp:
    # rstrip instead of line[:-1]: the final line may lack a trailing newline
    # (line[:-1] would chop a real character), and CRLF files would keep "\r".
    labels_list = [line.rstrip("\r\n") for line in fp]
# RGB colour table indexed by class id.
colormap = np.asarray(my_palette())
def greet(input_img):
    """Segment an image with SegFormer and return an overlay figure.

    Args:
        input_img: The uploaded image. Presumably a NumPy array (Gradio's
            default for ``"image"`` inputs; confirm against the installed
            Gradio version). PIL images also work, because the image is only
            accessed through ``np.asarray``.

    Returns:
        The figure built by ``draw_plot``: the input blended 50/50 with the
        colour-coded class mask, plus a legend of the classes present.

    Raises:
        ValueError: If no image was supplied.
    """
    import torch  # local import: only needed for no_grad during inference

    if input_img is None:
        raise ValueError("No input image provided.")

    # Work on a channels-last array. ``ndarray.size`` is an element count,
    # so the old ``input_img.size[::-1]`` only worked for PIL images.
    rgb = np.asarray(input_img)
    if rgb.ndim == 2:
        # Greyscale: replicate into three channels.
        rgb = np.stack([rgb] * 3, axis=-1)
    elif rgb.shape[-1] == 4:
        # RGBA: drop alpha so the 3-channel blend below broadcasts.
        rgb = rgb[..., :3]
    height, width = rgb.shape[:2]

    inputs = feature_extractor(images=rgb, return_tensors="pt")
    # Inference only: skip building the autograd graph. A grad-tracking torch
    # tensor also cannot be exposed as a NumPy array, which TF needs below.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu().numpy()  # (batch_size, num_labels, height/4, width/4)

    # Channels-last for tf.image.resize, upsample to the input resolution,
    # then take the highest-scoring class per pixel.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    logits = tf.image.resize(logits, (height, width))
    seg = tf.math.argmax(logits, axis=-1)[0]

    # Paint each class id with its palette colour; ids without a colour stay black.
    seg_ids = seg.numpy()  # converted once, not once per class
    color_seg = np.zeros((height, width, 3), dtype=np.uint8)
    for label, color in enumerate(colormap):
        color_seg[seg_ids == label] = color

    # Show image + mask.
    pred_img = (rgb * 0.5 + color_seg * 0.5).astype(np.uint8)
    return draw_plot(pred_img, seg)
def draw_plot(pred_img, seg):
    """Build a figure showing the overlay image next to a class legend.

    Uses ``matplotlib.figure.Figure`` directly rather than pyplot. Pyplot
    keeps every figure alive in global state until it is explicitly closed,
    which leaks memory in a long-running server, and that state is not
    thread-safe.

    Args:
        pred_img: The RGB overlay image to display.
        seg: 2-D map of predicted class ids. Anything ``np.asarray`` accepts
            works, e.g. the TensorFlow tensor produced by ``greet``.

    Returns:
        matplotlib.figure.Figure: Left panel is the image; right panel is a
        colour swatch per class present in ``seg``, labelled from
        ``labels_list``.
    """
    from matplotlib.figure import Figure

    label_names = np.asarray(labels_list)
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    # Keep ids as integers (a uint8 cast could wrap), and drop ids that have
    # no name so the legend lookup cannot raise IndexError.
    unique_labels = np.unique(np.asarray(seg))
    unique_labels = unique_labels[(unique_labels >= 0) & (unique_labels < len(label_names))]

    fig = Figure(figsize=(20, 15))
    grid_spec = fig.add_gridspec(1, 2, width_ratios=[6, 1])

    ax_img = fig.add_subplot(grid_spec[0])
    ax_img.imshow(pred_img)
    ax_img.axis("off")

    ax_legend = fig.add_subplot(grid_spec[1])
    ax_legend.imshow(full_color_map[unique_labels].astype(np.uint8), interpolation="nearest")
    ax_legend.yaxis.tick_right()
    ax_legend.set_yticks(range(len(unique_labels)))
    ax_legend.set_yticklabels(label_names[unique_labels])
    ax_legend.set_xticks([])
    ax_legend.tick_params(width=0.0, labelsize=25)
    return fig
def label_to_color_image(label, palette=None):
    """Map a 2-D array of class ids to RGB colours.

    Args:
        label: 2-D integer array (or nested list) of class ids.
        palette: Optional colour table indexed by class id, one RGB row per
            class. Defaults to the module-level ``colormap``.

    Returns:
        numpy.ndarray: Shape ``label.shape + (3,)``; each id replaced by its
        palette row.

    Raises:
        ValueError: If ``label`` is not 2-D, contains a negative id (which
            would otherwise silently index from the end of the palette), or
            contains an id with no palette entry.
    """
    if palette is None:
        palette = colormap
    palette = np.asarray(palette)
    label = np.asarray(label)
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    # np.min/np.max raise on empty arrays, so only range-check non-empty input.
    if label.size:
        if np.min(label) < 0:
            raise ValueError("label value must be non-negative.")
        if np.max(label) >= len(palette):
            raise ValueError("label value too large.")
    return palette[label]
# Gradio UI: one uploaded image in, one matplotlib figure out.
_example_files = [f"image ({n})" for n in range(1, 6)]
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs=["plot"],
    examples=_example_files,
    allow_flagging="never",
)
iface.launch(share=True)
|