Karin0616 commited on
Commit
1f6afca
ยท
1 Parent(s): 86b082c
Files changed (1) hide show
  1. app.py +132 -49
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  from matplotlib import gridspec
3
  import matplotlib.pyplot as plt
4
  import numpy as np
@@ -14,26 +16,28 @@ model = TFSegformerForSemanticSegmentation.from_pretrained(
14
  )
15
 
16
  def ade_palette():
 
17
  return [
18
- [204, 87, 92], # road (Reddish)
19
  [112, 185, 212], # sidewalk (Blue)
20
  [196, 160, 122], # building (Brown)
21
  [106, 135, 242], # wall (Light Blue)
22
- [91, 192, 222], # fence (Turquoise)
23
  [255, 192, 203], # pole (Pink)
24
  [176, 224, 230], # traffic light (Light Blue)
25
- [222, 49, 99], # traffic sign (Red)
26
- [139, 69, 19], # vegetation (Brown)
27
- [255, 0, 0], # terrain (Red)
28
- [0, 0, 255], # sky (Blue)
29
  [255, 228, 181], # person (Peach)
30
- [128, 0, 0], # rider (Maroon)
31
- [0, 128, 0], # car (Green)
32
- [255, 99, 71], # truck (Tomato)
33
- [0, 255, 0], # bus (Lime)
34
- [128, 0, 128], # train (Purple)
35
- [255, 255, 0], # motorcycle (Yellow)
36
- [128, 0, 128] # bicycle (Purple)
 
37
  ]
38
 
39
  labels_list = []
@@ -73,14 +77,7 @@ def draw_plot(pred_img, seg):
73
  ax.tick_params(width=0.0, labelsize=25)
74
  return fig
75
 
76
- def sepia(input_img, *label_buttons):
77
- selected_color = None
78
- for label, button_state in zip(labels_list, label_buttons):
79
- if button_state:
80
- label_index = labels_list.index(label)
81
- selected_color = colormap[label_index]
82
- break
83
-
84
  input_img = Image.fromarray(input_img)
85
 
86
  inputs = feature_extractor(images=input_img, return_tensors="tf")
@@ -90,43 +87,129 @@ def sepia(input_img, *label_buttons):
90
  logits = tf.transpose(logits, [0, 2, 3, 1])
91
  logits = tf.image.resize(
92
  logits, input_img.size[::-1]
93
- )
94
  seg = tf.math.argmax(logits, axis=-1)[0]
95
 
96
  color_seg = np.zeros(
97
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
98
- )
99
- if selected_color:
100
- label = colormap.index(selected_color)
101
- color_seg[seg.numpy() == label, :] = selected_color
102
 
 
103
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
104
  pred_img = pred_img.astype(np.uint8)
105
 
106
  fig = draw_plot(pred_img, seg)
107
  return fig
108
 
109
- # ๋ผ๋ฒจ ๋ฒ„ํŠผ ์ƒ์„ฑ
110
- label_buttons = [gr.Button(label) for label in labels_list]
111
-
112
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
113
- iface = gr.Interface(
114
- fn=sepia,
115
- inputs=[gr.Image(shape=(564, 846))] + label_buttons,
116
- outputs="plot",
117
- live=True,
118
- examples=["city1.jpg", "city2.jpg", "city3.jpg"],
119
- allow_flagging='never',
120
- title="This is a machine learning activity project at Kyunggi University.",
121
- theme="darkpeach",
122
- css="""
123
- body{
124
- background-color: dark;
125
- color: white;
126
- font-family: Arial, sans-serif;
127
- }
128
- """
129
- )
130
 
131
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์‹œ์ž‘
132
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import random
3
+
4
  from matplotlib import gridspec
5
  import matplotlib.pyplot as plt
6
  import numpy as np
 
16
  )
17
 
18
  def ade_palette():
19
+
20
  return [
21
+ [204, 87, 92], # road (Reddish)
22
  [112, 185, 212], # sidewalk (Blue)
23
  [196, 160, 122], # building (Brown)
24
  [106, 135, 242], # wall (Light Blue)
25
+ [91, 192, 222], # fence (Turquoise)
26
  [255, 192, 203], # pole (Pink)
27
  [176, 224, 230], # traffic light (Light Blue)
28
+ [222, 49, 99], # traffic sign (Red)
29
+ [139, 69, 19], # vegetation (Brown)
30
+ [255, 0, 0], # terrain (Red)
31
+ [0, 0, 255], # sky (Blue)
32
  [255, 228, 181], # person (Peach)
33
+ [128, 0, 0], # rider (Maroon)
34
+ [0, 128, 0], # car (Green)
35
+ [255, 99, 71], # truck (Tomato)
36
+ [0, 255, 0], # bus (Lime)
37
+ [128, 0, 128], # train (Purple)
38
+ [255, 255, 0], # motorcycle (Yellow)
39
+ [128, 0, 128] # bicycle (Purple)
40
+
41
  ]
42
 
43
  labels_list = []
 
77
  ax.tick_params(width=0.0, labelsize=25)
78
  return fig
79
 
80
+ def sepia(input_img):
 
 
 
 
 
 
 
81
  input_img = Image.fromarray(input_img)
82
 
83
  inputs = feature_extractor(images=input_img, return_tensors="tf")
 
87
  logits = tf.transpose(logits, [0, 2, 3, 1])
88
  logits = tf.image.resize(
89
  logits, input_img.size[::-1]
90
+ ) # We reverse the shape of `image` because `image.size` returns width and height.
91
  seg = tf.math.argmax(logits, axis=-1)[0]
92
 
93
  color_seg = np.zeros(
94
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
95
+ ) # height, width, 3
96
+ for label, color in enumerate(colormap):
97
+ color_seg[seg.numpy() == label, :] = color
 
98
 
99
+ # Show image + mask
100
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
101
  pred_img = pred_img.astype(np.uint8)
102
 
103
  fig = draw_plot(pred_img, seg)
104
  return fig
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ with gr.Blocks() as demo:
108
+ section_labels = [
109
+ "road",
110
+ "sidewalk",
111
+ "building",
112
+ "wall",
113
+ "fence",
114
+ "pole",
115
+ "traffic light",
116
+ "traffic sign",
117
+ "vegetation",
118
+ "terrain",
119
+ "sky",
120
+ "person",
121
+ "rider",
122
+ "car",
123
+ "truck",
124
+ "bus",
125
+ "train",
126
+ "motorcycle",
127
+ "bicycle"
128
+ ]
129
+
130
+ with gr.Row():
131
+ num_boxes = gr.Slider(1, 1, 1, step=0, label="Number of boxes")
132
+ num_segments = gr.Slider(0, 19, 1, step=1, label="Number of segments")
133
+
134
+ with gr.Row():
135
+ img_input = gr.Image()
136
+ img_output = gr.AnnotatedImage(
137
+ color_map={
138
+ "road": "#CC575C",
139
+ "sidewalk": "#70B9D4",
140
+ "building": "#C4A07A",
141
+ "wall": "#6A87F2",
142
+ "fence": "#5BC0DE",
143
+ "pole": "#FFC0CB",
144
+ "traffic light": "#B0E0E6",
145
+ "traffic sign": "#DE3163",
146
+ "vegetation": "#8B4513",
147
+ "terrain": "#FF0000",
148
+ "sky": "#0000FF",
149
+ "person": "#FFE4B5",
150
+ "rider": "#800000",
151
+ "car": "#008000",
152
+ "truck": "#FF6347",
153
+ "bus": "#00FF00",
154
+ "train": "#800080",
155
+ "motorcycle": "#FFFF00",
156
+ "bicycle": "#800080"}
157
+ )
158
+
159
+ section_btn = gr.Button("Identify Sections")
160
+ selected_section = gr.Textbox(label="Selected Section")
161
+
162
+
163
+ def section(img, num_boxes, num_segments):
164
+ sections = []
165
+
166
+ for a in range(num_boxes):
167
+ x = random.randint(0, img.shape[1])
168
+ y = random.randint(0, img.shape[0])
169
+ w = random.randint(0, img.shape[1] - x)
170
+ h = random.randint(0, img.shape[0] - y)
171
+ sections.append(((x, y, x + w, y + h), section_labels[a]))
172
+ for b in range(num_segments):
173
+ x = random.randint(0, img.shape[1])
174
+ y = random.randint(0, img.shape[0])
175
+ r = random.randint(0, min(x, y, img.shape[1] - x, img.shape[0] - y))
176
+ mask = np.zeros(img.shape[:2])
177
+ for i in range(img.shape[0]):
178
+ for j in range(img.shape[1]):
179
+ dist_square = (i - y) ** 2 + (j - x) ** 2
180
+ if dist_square < r ** 2:
181
+ mask[i, j] = round((r ** 2 - dist_square) / r ** 2 * 4) / 4
182
+ sections.append((mask, section_labels[b + num_boxes]))
183
+ return (img, sections)
184
+
185
+
186
+ section_btn.click(section, [img_input, num_boxes, num_segments], img_output)
187
+
188
+
189
+ def select_section(evt: gr.SelectData):
190
+ return section_labels[evt.index]
191
+
192
+
193
+ img_output.select(select_section, None, selected_section)
194
+
195
+ demo = gr.Interface(fn=sepia,
196
+ inputs=gr.Image(shape=(564,846)),
197
+ outputs=['plot'],
198
+ live=True,
199
+ examples=["city1.jpg","city2.jpg","city3.jpg"],
200
+ allow_flagging='never',
201
+ title="This is a machine learning activity project at Kyunggi University.",
202
+ theme="darkpeach",
203
+ css="""
204
+ body {
205
+ background-color: dark;
206
+ color: white; /* ํฐํŠธ ์ƒ‰์ƒ ์ˆ˜์ • */
207
+ font-family: Arial, sans-serif; /* ํฐํŠธ ํŒจ๋ฐ€๋ฆฌ ์ˆ˜์ • */
208
+ }
209
+ """
210
+
211
+ )
212
+
213
+
214
+ demo.launch()
215
+