Karin0616 committed on
Commit
86b082c
·
1 Parent(s): 2bee93c
Files changed (1) hide show
  1. app.py +50 -42
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import gradio as gr
2
-
3
-
4
  from matplotlib import gridspec
5
  import matplotlib.pyplot as plt
6
  import numpy as np
@@ -16,28 +14,26 @@ model = TFSegformerForSemanticSegmentation.from_pretrained(
16
  )
17
 
18
  def ade_palette():
19
-
20
  return [
21
- [204, 87, 92], # road (Reddish)
22
  [112, 185, 212], # sidewalk (Blue)
23
  [196, 160, 122], # building (Brown)
24
  [106, 135, 242], # wall (Light Blue)
25
- [91, 192, 222], # fence (Turquoise)
26
  [255, 192, 203], # pole (Pink)
27
  [176, 224, 230], # traffic light (Light Blue)
28
- [222, 49, 99], # traffic sign (Red)
29
- [139, 69, 19], # vegetation (Brown)
30
- [255, 0, 0], # terrain (Red)
31
- [0, 0, 255], # sky (Blue)
32
  [255, 228, 181], # person (Peach)
33
- [128, 0, 0], # rider (Maroon)
34
- [0, 128, 0], # car (Green)
35
- [255, 99, 71], # truck (Tomato)
36
- [0, 255, 0], # bus (Lime)
37
- [128, 0, 128], # train (Purple)
38
- [255, 255, 0], # motorcycle (Yellow)
39
- [128, 0, 128] # bicycle (Purple)
40
-
41
  ]
42
 
43
  labels_list = []
@@ -77,7 +73,14 @@ def draw_plot(pred_img, seg):
77
  ax.tick_params(width=0.0, labelsize=25)
78
  return fig
79
 
80
- def sepia(input_img):
 
 
 
 
 
 
 
81
  input_img = Image.fromarray(input_img)
82
 
83
  inputs = feature_extractor(images=input_img, return_tensors="tf")
@@ -87,38 +90,43 @@ def sepia(input_img):
87
  logits = tf.transpose(logits, [0, 2, 3, 1])
88
  logits = tf.image.resize(
89
  logits, input_img.size[::-1]
90
- ) # We reverse the shape of `image` because `image.size` returns width and height.
91
  seg = tf.math.argmax(logits, axis=-1)[0]
92
 
93
  color_seg = np.zeros(
94
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
95
- ) # height, width, 3
96
- for label, color in enumerate(colormap):
97
- color_seg[seg.numpy() == label, :] = color
 
98
 
99
- # Show image + mask
100
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
101
  pred_img = pred_img.astype(np.uint8)
102
 
103
  fig = draw_plot(pred_img, seg)
104
  return fig
105
 
106
- demo = gr.Interface(fn=sepia,
107
- inputs=gr.Image(shape=(564,846)),
108
- outputs=['plot'],
109
- live=True,
110
- examples=["city1.jpg","city2.jpg","city3.jpg"],
111
- allow_flagging='never',
112
- title="This is a machine learning activity project at Kyunggi University.",
113
- theme="darkpeach",
114
- css="""
115
- body {
116
- background-color: dark;
117
- color: white; /* 폰트 색상 수정 */
118
- font-family: Arial, sans-serif; /* 폰트 패밀리 수정 */
119
- }
120
- """
121
- )
122
-
123
-
124
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
 
 
2
  from matplotlib import gridspec
3
  import matplotlib.pyplot as plt
4
  import numpy as np
 
14
  )
15
 
16
  def ade_palette():
 
17
  return [
18
+ [204, 87, 92], # road (Reddish)
19
  [112, 185, 212], # sidewalk (Blue)
20
  [196, 160, 122], # building (Brown)
21
  [106, 135, 242], # wall (Light Blue)
22
+ [91, 192, 222], # fence (Turquoise)
23
  [255, 192, 203], # pole (Pink)
24
  [176, 224, 230], # traffic light (Light Blue)
25
+ [222, 49, 99], # traffic sign (Red)
26
+ [139, 69, 19], # vegetation (Brown)
27
+ [255, 0, 0], # terrain (Red)
28
+ [0, 0, 255], # sky (Blue)
29
  [255, 228, 181], # person (Peach)
30
+ [128, 0, 0], # rider (Maroon)
31
+ [0, 128, 0], # car (Green)
32
+ [255, 99, 71], # truck (Tomato)
33
+ [0, 255, 0], # bus (Lime)
34
+ [128, 0, 128], # train (Purple)
35
+ [255, 255, 0], # motorcycle (Yellow)
36
+ [128, 0, 128] # bicycle (Purple)
 
37
  ]
38
 
39
  labels_list = []
 
73
  ax.tick_params(width=0.0, labelsize=25)
74
  return fig
75
 
76
+ def sepia(input_img, *label_buttons):
77
+ selected_color = None
78
+ for label, button_state in zip(labels_list, label_buttons):
79
+ if button_state:
80
+ label_index = labels_list.index(label)
81
+ selected_color = colormap[label_index]
82
+ break
83
+
84
  input_img = Image.fromarray(input_img)
85
 
86
  inputs = feature_extractor(images=input_img, return_tensors="tf")
 
90
  logits = tf.transpose(logits, [0, 2, 3, 1])
91
  logits = tf.image.resize(
92
  logits, input_img.size[::-1]
93
+ )
94
  seg = tf.math.argmax(logits, axis=-1)[0]
95
 
96
  color_seg = np.zeros(
97
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
98
+ )
99
+ if selected_color:
100
+ label = colormap.index(selected_color)
101
+ color_seg[seg.numpy() == label, :] = selected_color
102
 
 
103
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
104
  pred_img = pred_img.astype(np.uint8)
105
 
106
  fig = draw_plot(pred_img, seg)
107
  return fig
108
 
109
+ # 라벨 버튼 생성
110
+ label_buttons = [gr.Button(label) for label in labels_list]
111
+
112
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
113
+ iface = gr.Interface(
114
+ fn=sepia,
115
+ inputs=[gr.Image(shape=(564, 846))] + label_buttons,
116
+ outputs="plot",
117
+ live=True,
118
+ examples=["city1.jpg", "city2.jpg", "city3.jpg"],
119
+ allow_flagging='never',
120
+ title="This is a machine learning activity project at Kyunggi University.",
121
+ theme="darkpeach",
122
+ css="""
123
+ body{
124
+ background-color: dark;
125
+ color: white;
126
+ font-family: Arial, sans-serif;
127
+ }
128
+ """
129
+ )
130
+
131
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์‹œ์ž‘
132
+ iface.launch()