File size: 3,769 Bytes
c38a591 a988bac 8de291e b95fb21 a988bac 8de291e a988bac c38a591 a988bac 8de291e 5280321 8de291e b95fb21 8de291e b95fb21 8de291e b95fb21 8de291e 4a7ca2e ba76dd6 b95fb21 8de291e b95fb21 8de291e b95fb21 8de291e a988bac 8de291e a988bac 8de291e a988bac 8de291e b95fb21 8de291e a988bac 8de291e a988bac 8de291e ef7902e 8de291e ef7902e c38a591 ef7902e 8de291e 37451d1 8de291e 9b95409 b95fb21 8de291e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import gradio as gr
#
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import tensorflow as tf
from PIL import Image
from io import BytesIO
import requests
#
# One Hub checkpoint id drives both the preprocessing object and the network,
# so the two cannot drift apart.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-cityscapes-640-1280"

# Preprocessor that turns input images into model-ready tensors.
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
# Segmentation network loaded from the same checkpoint.
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
# Reference snippet (disabled): run the model on a batch of sample images.
# urls = ["http://farm3.staticflickr.com/2523/3705549787_79049b1b6d_z.jpg",
#         "http://farm8.staticflickr.com/7012/6476201279_52db36af64_z.jpg",
#         "http://farm8.staticflickr.com/7180/6967423255_a3d65d5f6b_z.jpg",
#         "http://farm4.staticflickr.com/3563/3470840644_3378804bea_z.jpg",
#         "http://farm9.staticflickr.com/8388/8516454091_0ebdc1130a_z.jpg"]
# images = []
# for i in urls:
#     images.append(Image.open(requests.get(i, stream=True).raw))
# inputs = feature_extractor(images=images, return_tensors="pt")
# outputs = model(**inputs)
# logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)
# Class names read from labels.txt, one per line; the position in the list is
# the class index used to look the name up when the legend is drawn.
# Only the line terminator is stripped, so the final label stays intact even
# when the file does not end with a newline.
with open("labels.txt", "r", encoding="utf-8") as fp:
    labels_list = [line.rstrip("\n") for line in fp]
# RGB palette indexed by class id: row i is the overlay/legend colour of class i.
# The Cityscapes checkpoint predicts 19 classes (ids 0-18), so the palette needs
# 19 rows; with fewer, the last class is never painted and building the legend
# for a full label list fails the range check in label_to_color_image.
colormap = np.asarray([
    [131, 162, 255],
    [180, 189, 255],
    [255, 227, 187],
    [255, 210, 143],
    [248, 117, 170],
    [255, 223, 223],
    [255, 246, 246],
    [174, 222, 252],
    [150, 194, 145],
    [255, 219, 170],
    [244, 238, 238],
    [50, 38, 83],
    [128, 98, 214],
    [146, 136, 248],
    [255, 210, 215],
    [255, 152, 152],
    [162, 103, 138],
    [63, 29, 56],
    [255, 178, 102],
])
# with open(r"labels.txt", "r") as fp:
# for line in fp:
# labels_list.append(line[:-1])
#
# colormap = np.asarray(my_palette())
def greet(input_img):
    """Segment an image and return a figure with the colour overlay and legend.

    Args:
        input_img: The picture to segment. May be a pixel array (anything
            ``np.asarray`` can turn into an image array), a PIL image, or
            the raw bytes of an encoded image file.

    Returns:
        The Matplotlib figure produced by ``draw_plot``.

    Raises:
        ValueError: If ``input_img`` is ``None``.
    """
    if input_img is None:
        raise ValueError("No input image provided.")

    # Normalise every accepted input form to a PIL image.
    if isinstance(input_img, Image.Image):
        image = input_img
    elif isinstance(input_img, (bytes, bytearray, memoryview)):
        image = Image.open(BytesIO(input_img))
    else:
        image = Image.fromarray(np.asarray(input_img))
    # The overlay below blends with a 3-channel colour image, so force RGB
    # (RGBA or grayscale input would not line up with it).
    image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)

    # Convert the logits to NumPy exactly once, then move the class axis last
    # (axes 0, 2, 3, 1) so tf.image.resize sees a channels-last batch.
    logits = outputs.logits.detach().numpy()
    logits = np.transpose(logits, (0, 2, 3, 1))
    # PIL reports (width, height); resize expects (height, width).
    logits = tf.image.resize(logits, image.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]

    # Paint each pixel with the palette colour of its predicted class; classes
    # without a palette entry stay black.
    seg_ids = seg.numpy()
    color_seg = np.zeros((seg_ids.shape[0], seg_ids.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(colormap):
        color_seg[seg_ids == label, :] = color

    # 50/50 blend of the photo and the class colours.
    pred_img = np.array(image) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    return draw_plot(pred_img, seg)
def draw_plot(pred_img, seg):
    """Build a figure showing the overlay beside a legend of the classes present.

    Args:
        pred_img: Image array shown in the wide left panel.
        seg: 2-D map of class indices; any array-like works, including a
            TensorFlow eager tensor.

    Returns:
        The Matplotlib figure holding both panels.
    """
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Draw through explicit axes rather than pyplot's implicit "current axes",
    # so concurrent requests cannot draw into each other's figures.
    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    label_names = np.asarray(labels_list)
    # Column vector of every class id -> one colour swatch per class.
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)
    # Only classes that actually occur in this image appear in the legend.
    unique_labels = np.unique(np.asarray(seg).astype("uint8"))

    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(
        full_color_map[unique_labels].astype(np.uint8), interpolation="nearest"
    )
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(list(range(len(unique_labels))))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)

    # Unregister the figure from pyplot so a long-running server does not pile
    # up one open figure per request; the returned object can still be rendered.
    plt.close(fig)
    return fig
def label_to_color_image(label):
    """Translate a 2-D array of class indices into palette colours.

    Args:
        label: 2-D integer array of class indices.

    Returns:
        ``colormap`` indexed by ``label``, i.e. one palette row per element.

    Raises:
        ValueError: If ``label`` is not 2-D, or if its largest value has no
            row in ``colormap``.
    """
    rank = label.ndim
    if rank != 2:
        raise ValueError("Expect 2-D input label")

    highest_index = np.max(label)
    palette_size = len(colormap)
    if highest_index >= palette_size:
        raise ValueError("label value too large.")

    return colormap[label]
# Sample image files offered as one-click examples in the UI.
_example_images = [
    "image (1).jpg",
    "image (2).jpg",
    "image (3).jpg",
    "image (4).jpg",
    "image (5).jpg",
]

# Web UI: one image in, one plot out, with flagging switched off.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs=["plot"],
    examples=_example_images,
    allow_flagging="never",
)

iface.launch(share=True)
|