Hantr commited on
Commit
48ecfaa
ยท
1 Parent(s): de591bd
Files changed (1) hide show
  1. app.py +23 -46
app.py CHANGED
@@ -48,6 +48,7 @@ with open(r'labels.txt', 'r') as fp:
48
 
49
  colormap = np.asarray(ade_palette())
50
 
 
51
  def label_to_color_image(label):
52
  if label.ndim != 2:
53
  raise ValueError("Expect 2-D input label")
@@ -57,26 +58,14 @@ def label_to_color_image(label):
57
  return colormap[label]
58
 
59
 
60
- def draw_plot(pred_img, seg):
61
- fig = plt.figure(figsize=(20, 15))
62
-
63
- grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
64
 
65
- plt.subplot(grid_spec[0])
66
- plt.imshow(pred_img)
67
- plt.axis('off')
68
- LABEL_NAMES = np.asarray(labels_list)
69
- FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
70
- FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
71
 
72
- unique_labels = np.unique(seg.numpy().astype("uint8"))
73
- ax = plt.subplot(grid_spec[1])
74
- plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
75
- ax.yaxis.tick_right()
76
- plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
77
- plt.xticks([], [])
78
- ax.tick_params(width=0.0, labelsize=25)
79
- return fig
80
 
81
  def sepia(input_img):
82
  input_img = Image.fromarray(input_img)
@@ -88,43 +77,31 @@ def sepia(input_img):
88
  logits = tf.transpose(logits, [0, 2, 3, 1])
89
  logits = tf.image.resize(
90
  logits, input_img.size[::-1]
91
- ) # We reverse the shape of `image` because `image.size` returns width and height.
92
  seg = tf.math.argmax(logits, axis=-1)[0]
93
 
94
- color_seg = np.zeros(
95
- (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
96
- ) # height, width, 3
97
- for label, color in enumerate(colormap):
98
- color_seg[seg.numpy() == label, :] = color
99
-
100
- # Show image + mask
101
- pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
102
- pred_img = pred_img.astype(np.uint8)
103
 
104
- fig = draw_plot(pred_img, seg)
105
- return fig
106
 
107
 
108
- def segment_image(input_img):
109
- input_img = Image.fromarray(input_img)
110
-
111
- # ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์ˆ˜ํ–‰
112
- inputs = feature_extractor(images=input_img, return_tensors="tf")
113
- outputs = model(**inputs)
114
- logits = outputs.logits
115
-
116
- logits = tf.transpose(logits, [0, 2, 3, 1])
117
- logits = tf.image.resize(
118
- logits, input_img.size[::-1]
119
- )
120
- seg = tf.math.argmax(logits, axis=-1)[0]
121
 
122
- return input_img, seg
 
 
 
 
 
123
 
124
 
125
- demo = gr.Interface(fn=segment_image,
126
  inputs=gr.Image(shape=(1024, 1024)),
127
- outputs=["image", "image"],
128
  examples=["city-1.jpg", "city-2.jpg", "city-3.jpg", "city-4.jpg", "city-5.jpg"],
129
  allow_flagging='never')
130
 
 
48
 
49
  colormap = np.asarray(ade_palette())
50
 
51
+
52
  def label_to_color_image(label):
53
  if label.ndim != 2:
54
  raise ValueError("Expect 2-D input label")
 
58
  return colormap[label]
59
 
60
 
61
+ def draw_class_visualization(seg, class_id):
62
+ class_mask = seg.numpy() == class_id
63
+ class_color = colormap[class_id]
 
64
 
65
+ class_visualization = np.zeros(seg.shape + (3,))
66
+ class_visualization[class_mask] = class_color
 
 
 
 
67
 
68
+ return class_visualization
 
 
 
 
 
 
 
69
 
70
  def sepia(input_img):
71
  input_img = Image.fromarray(input_img)
 
77
  logits = tf.transpose(logits, [0, 2, 3, 1])
78
  logits = tf.image.resize(
79
  logits, input_img.size[::-1]
80
+ )
81
  seg = tf.math.argmax(logits, axis=-1)[0]
82
 
83
+ class_visualizations = []
84
+ for class_id in range(len(colormap)):
85
+ class_visualization = draw_class_visualization(seg, class_id)
86
+ class_visualizations.append(class_visualization)
 
 
 
 
 
87
 
88
+ return class_visualizations
 
89
 
90
 
91
+ def plot_class_visualization(class_visualizations):
92
+ fig, axes = plt.subplots(1, len(class_visualizations), figsize=(20, 15))
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ for i, class_visualization in enumerate(class_visualizations):
95
+ ax = axes[i]
96
+ ax.imshow(class_visualization)
97
+ ax.axis('off')
98
+ ax.set_title(labels_list[i])
99
+ return fig
100
 
101
 
102
+ demo = gr.Interface(fn=sepia,
103
  inputs=gr.Image(shape=(1024, 1024)),
104
+ outputs=gr.outputs.Image(type='plot', label="Class Visualizations"),
105
  examples=["city-1.jpg", "city-2.jpg", "city-3.jpg", "city-4.jpg", "city-5.jpg"],
106
  allow_flagging='never')
107