Hantr commited on
Commit
6d1120a
ยท
1 Parent(s): ba0292c
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  from matplotlib import gridspec
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- from PIL import Image
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
 
@@ -58,7 +58,7 @@ def label_to_color_image(label):
58
  return colormap[label]
59
 
60
 
61
- def draw_plot(pred_img, seg):
62
  fig = plt.figure(figsize=(20, 15))
63
 
64
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
@@ -66,6 +66,8 @@ def draw_plot(pred_img, seg):
66
  plt.subplot(grid_spec[0])
67
  plt.imshow(pred_img)
68
  plt.axis('off')
 
 
69
  LABEL_NAMES = np.asarray(labels_list)
70
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
71
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
@@ -77,6 +79,18 @@ def draw_plot(pred_img, seg):
77
  plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
78
  plt.xticks([], [])
79
  ax.tick_params(width=0.0, labelsize=25)
 
 
 
 
 
 
 
 
 
 
 
 
80
  return fig
81
 
82
 
@@ -91,7 +105,6 @@ def sepia(input_img):
91
  logits, input_img.size[::-1]
92
  ) # We reverse the shape of `image` because `image.size` returns width and height.
93
  seg = tf.math.argmax(logits, axis=-1)[0]
94
-
95
  color_seg = np.zeros(
96
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
97
  ) # height, width, 3
@@ -102,7 +115,7 @@ def sepia(input_img):
102
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
103
  pred_img = pred_img.astype(np.uint8)
104
 
105
- fig = draw_plot(pred_img, seg)
106
  return fig
107
 
108
 
 
3
  from matplotlib import gridspec
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
+ from PIL import Image, ImageDraw
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
 
 
58
  return colormap[label]
59
 
60
 
61
+ def draw_plot_with_labels(pred_img, seg):
62
  fig = plt.figure(figsize=(20, 15))
63
 
64
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
 
66
  plt.subplot(grid_spec[0])
67
  plt.imshow(pred_img)
68
  plt.axis('off')
69
+
70
+ # ๋ผ๋ฒจ ์ด๋ฆ„์„ ์ถ”๊ฐ€ํ•˜๊ธฐ ์œ„ํ•œ ์ฝ”๋“œ
71
  LABEL_NAMES = np.asarray(labels_list)
72
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
73
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
 
79
  plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
80
  plt.xticks([], [])
81
  ax.tick_params(width=0.0, labelsize=25)
82
+
83
+ # ๋ผ๋ฒจ ์ด๋ฆ„ ํ…์ŠคํŠธ ์ถ”๊ฐ€
84
+ draw = ImageDraw.Draw(pred_img)
85
+ for label, color in enumerate(colormap):
86
+ mask = seg.numpy() == label
87
+ if np.any(mask):
88
+ y, x = np.where(mask)
89
+ y = np.mean(y).astype(int)
90
+ x = np.mean(x).astype(int)
91
+ label_name = LABEL_NAMES[label]
92
+ draw.text((x, y), label_name, fill=tuple(color), fontsize=20)
93
+
94
  return fig
95
 
96
 
 
105
  logits, input_img.size[::-1]
106
  ) # We reverse the shape of `image` because `image.size` returns width and height.
107
  seg = tf.math.argmax(logits, axis=-1)[0]
 
108
  color_seg = np.zeros(
109
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
110
  ) # height, width, 3
 
115
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
116
  pred_img = pred_img.astype(np.uint8)
117
 
118
+ fig = draw_plot_with_labels(pred_img, seg)
119
  return fig
120
 
121