22 / app.py
hyo37009's picture
a
ba76dd6
raw
history blame
3.77 kB
import gradio as gr
#
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import tensorflow as tf
from PIL import Image
from io import BytesIO
import requests
#
# The preprocessor and the segmentation network are loaded from the same
# pretrained checkpoint, so the identifier is spelled out only once.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-cityscapes-640-1280"
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
# urls = ["http://farm3.staticflickr.com/2523/3705549787_79049b1b6d_z.jpg",
# "http://farm8.staticflickr.com/7012/6476201279_52db36af64_z.jpg",
# "http://farm8.staticflickr.com/7180/6967423255_a3d65d5f6b_z.jpg",
# "http://farm4.staticflickr.com/3563/3470840644_3378804bea_z.jpg",
# "http://farm9.staticflickr.com/8388/8516454091_0ebdc1130a_z.jpg"]
# images = []
# for i in urls:
# images.append(Image.open(requests.get(i, stream=True).raw))
# inputs = feature_extractor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
# Class names read from labels.txt, one per line. draw_plot indexes this list
# with the class ids found in the segmentation map, so position == class id.
with open("labels.txt", "r") as fp:
    # Strip only the line terminator. Slicing off the last character
    # unconditionally would truncate the final label whenever the file does
    # not end with a newline.
    labels_list = [line.rstrip("\n") for line in fp]
# Overlay palette: row i holds the RGB color used for class id i.
# NOTE(review): 18 rows are defined here; confirm this matches the number of
# entries in labels.txt and the number of classes the model predicts, since
# label_to_color_image rejects any class id >= len(colormap).
colormap = np.array(
    (
        (131, 162, 255),
        (180, 189, 255),
        (255, 227, 187),
        (255, 210, 143),
        (248, 117, 170),
        (255, 223, 223),
        (255, 246, 246),
        (174, 222, 252),
        (150, 194, 145),
        (255, 219, 170),
        (244, 238, 238),
        (50, 38, 83),
        (128, 98, 214),
        (146, 136, 248),
        (255, 210, 215),
        (255, 152, 152),
        (162, 103, 138),
        (63, 29, 56),
    )
)
# with open(r"labels.txt", "r") as fp:
# for line in fp:
# labels_list.append(line[:-1])
#
# colormap = np.asarray(my_palette())
def greet(input_img):
    """Segment an image and return a figure with the color overlay and legend.

    Args:
        input_img: Image to segment. A NumPy array (what Gradio's "image"
            input supplies by default), encoded image bytes, or a PIL image.

    Returns:
        The matplotlib figure produced by ``draw_plot``.

    Raises:
        ValueError: If no image was supplied.
    """
    if input_img is None:
        raise ValueError("No image provided.")

    # Normalize every supported input form to a PIL image.
    if isinstance(input_img, (bytes, bytearray)):
        input_img = Image.open(BytesIO(input_img))
    elif isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img)
    # Force three channels so the blend below lines up with color_seg even
    # for RGBA or grayscale uploads.
    input_img = input_img.convert("RGB")

    inputs = feature_extractor(images=input_img, return_tensors="pt")
    outputs = model(**inputs)

    # Detach from the autograd graph exactly once; the resulting NumPy array
    # has no .detach() of its own.
    logits = outputs.logits.detach().numpy()
    # Move the class axis last, as tf.image.resize expects channels-last.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # PIL reports (width, height); resize wants (height, width).
    logits = tf.image.resize(logits, input_img.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]

    # Convert once instead of on every loop iteration.
    seg_np = seg.numpy()
    color_seg = np.zeros((seg_np.shape[0], seg_np.shape[1], 3), dtype=np.uint8)
    # Class ids beyond the palette length keep the zero (black) fill.
    for label, color in enumerate(colormap):
        color_seg[seg_np == label, :] = color

    # 50/50 blend of the photo and the class colors.
    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    # Draw the plot
    fig = draw_plot(pred_img, seg)
    return fig
def draw_plot(pred_img, seg):
    """Build a two-panel figure: the blended image and a legend of present classes.

    Args:
        pred_img: Image array shown in the wide left panel.
        seg: Per-pixel class ids; must provide a ``.numpy()`` method.

    Returns:
        The matplotlib figure containing both panels.
    """
    fig = plt.figure(figsize=(20, 15))
    layout = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the overlay itself, without axes decorations.
    image_ax = plt.subplot(layout[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    # One color swatch per known label, stacked as a column.
    names = np.asarray(labels_list)
    label_count = len(names)
    every_label = np.arange(label_count).reshape(label_count, 1)
    every_color = label_to_color_image(every_label)

    # Only classes that actually occur in this segmentation are listed.
    present = np.unique(seg.numpy().astype("uint8"))

    # Right panel: swatches with class names as right-hand tick labels.
    legend_ax = plt.subplot(layout[1])
    legend_ax.imshow(every_color[present].astype(np.uint8), interpolation="nearest")
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(present)))
    legend_ax.set_yticklabels(names[present])
    legend_ax.set_xticks([])
    legend_ax.set_xticklabels([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig
def label_to_color_image(label):
    """Translate a 2-D array of class ids into their palette colors.

    Args:
        label: 2-D array of integer class ids.

    Returns:
        The result of indexing ``colormap`` with ``label``, i.e. one color
        row per input element.

    Raises:
        ValueError: If ``label`` is not 2-D, or if it contains an id that is
            not smaller than the number of palette rows.
    """
    rank_ok = label.ndim == 2
    if not rank_ok:
        raise ValueError("Expect 2-D input label")

    highest_id = np.max(label)
    if len(colormap) <= highest_id:
        raise ValueError("label value too large.")

    colors = colormap[label]
    return colors
# Sample image files offered as one-click examples in the UI.
example_images = [
    "image (1).jpg",
    "image (2).jpg",
    "image (3).jpg",
    "image (4).jpg",
    "image (5).jpg",
]

# One image in, one plot out; flagging is turned off.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs=["plot"],
    examples=example_images,
    allow_flagging="never",
)

# Start the app and request a public share link.
iface.launch(share=True)