GoBoKyung commited on
Commit
736ffb2
·
1 Parent(s): 8a07d05
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -70,7 +70,7 @@ def draw_plot(pred_img, seg):
70
  ax = plt.subplot(grid_spec[1])
71
  plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
72
  ax.yaxis.tick_right()
73
- plt.yticks(range(len unique_labels), LABEL_NAMES[unique_labels]),
74
  plt.xticks([], [])
75
  ax.tick_params(width=0.0, labelsize=25)
76
  return fig
@@ -83,11 +83,14 @@ def sepia(input_img):
83
  logits = outputs.logits
84
 
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
- input_img_size = input_img.size[::-1]
87
- logits = tf.image.resize(logits, input_img_size)
 
88
  seg = tf.math.argmax(logits, axis=-1)[0]
89
 
90
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
 
 
91
  for label, color in enumerate(colormap):
92
  color_seg[seg.numpy() == label, :] = color
93
 
@@ -99,9 +102,10 @@ def sepia(input_img):
99
  return fig
100
 
101
  demo = gr.Interface(fn=sepia,
102
- inputs=gr.Image(type="pil", label="Upload an Image"),
103
- outputs="plot",
104
  examples=["img_1.jpg", "img_2.jpeg", "img_3.jpg", "img_4.jpg", "img_5.png"],
105
  allow_flagging='never')
106
 
 
107
  demo.launch()
 
70
  ax = plt.subplot(grid_spec[1])
71
  plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
72
  ax.yaxis.tick_right()
73
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
74
  plt.xticks([], [])
75
  ax.tick_params(width=0.0, labelsize=25)
76
  return fig
 
83
  logits = outputs.logits
84
 
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
+ logits = tf.image.resize(
87
+ logits, input_img.size[::-1]
88
+ ) # We reverse the shape of `image` because `image.size` returns width and height.
89
  seg = tf.math.argmax(logits, axis=-1)[0]
90
 
91
+ color_seg = np.zeros(
92
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
93
+ ) # height, width, 3
94
  for label, color in enumerate(colormap):
95
  color_seg[seg.numpy() == label, :] = color
96
 
 
102
  return fig
103
 
104
  demo = gr.Interface(fn=sepia,
105
+ inputs=gr.Image(shape=(400, 600)),
106
+ outputs=['plot'],
107
  examples=["img_1.jpg", "img_2.jpeg", "img_3.jpg", "img_4.jpg", "img_5.png"],
108
  allow_flagging='never')
109
 
110
+
111
  demo.launch()