Daoneeee commited on
Commit
acb5613
ยท
1 Parent(s): d82dbc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -32
app.py CHANGED
@@ -1,21 +1,18 @@
1
  import gradio as gr
2
-
3
- from matplotlib import gridspec
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
- from PIL import Image
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
 
 
 
 
 
 
 
 
9
 
10
- feature_extractor = SegformerFeatureExtractor.from_pretrained(
11
- "mattmdjaga/segformer_b2_clothes"
12
- )
13
- model = TFSegformerForSemanticSegmentation.from_pretrained(
14
- "mattmdjaga/segformer_b2_clothes"
15
- )
16
 
 
17
  def ade_palette():
18
- """ADE20K palette that maps each class to RGB values."""
19
  return [
20
  [34, 116, 28],
21
  [84, 57, 0],
@@ -37,25 +34,28 @@ def ade_palette():
37
  [206, 114, 61],
38
  ]
39
 
 
40
  labels_list = []
41
 
42
- with open(r'labels.txt', 'r') as fp:
43
  for line in fp:
44
  labels_list.append(line[:-1])
45
 
46
  colormap = np.asarray(ade_palette())
47
 
 
 
48
  def label_to_color_image(label):
49
  if label.ndim != 2:
50
  raise ValueError("Expect 2-D input label")
51
-
52
  if np.max(label) >= len(colormap):
53
  raise ValueError("label value too large.")
54
  return colormap[label]
55
 
 
 
56
  def draw_plot(pred_img, seg):
57
  fig = plt.figure(figsize=(20, 15))
58
-
59
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
60
 
61
  plt.subplot(grid_spec[0])
@@ -64,7 +64,6 @@ def draw_plot(pred_img, seg):
64
  LABEL_NAMES = np.asarray(labels_list)
65
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
66
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
67
-
68
  unique_labels = np.unique(seg.numpy().astype("uint8"))
69
  ax = plt.subplot(grid_spec[1])
70
  plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
@@ -75,40 +74,40 @@ def draw_plot(pred_img, seg):
75
  return fig
76
 
77
 
78
- def sepia(input_img):
 
79
  input_img = Image.fromarray(input_img)
80
 
 
81
  inputs = feature_extractor(images=input_img, return_tensors="tf")
82
  outputs = model(**inputs)
83
  logits = outputs.logits
84
-
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
 
87
- # ํฌ๊ธฐ ์กฐ์ • ์ฝ”๋“œ ์ถ”๊ฐ€
88
- logits = tf.image.resize(
89
- logits, [input_img.size[1], input_img.size[0]]
90
- )
91
 
92
  seg = tf.math.argmax(logits, axis=-1)[0]
 
93
 
94
- color_seg = np.zeros(
95
- (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
96
- ) # height, width, 3
97
  for label, color in enumerate(colormap):
98
  color_seg[seg.numpy() == label, :] = color
99
 
100
- # Show image + mask
101
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
102
  pred_img = pred_img.astype(np.uint8)
103
 
104
  fig = draw_plot(pred_img, seg)
105
  return fig
106
 
107
- demo = gr.Interface(fn=sepia,
108
- inputs=gr.Image(type='pil', preprocess=None),
109
- outputs=['plot'],
110
- examples=["person-1.jpg", "person-2.jpg", "person-3.jpg", "person-4.jpg", "person-5.jpg"],
111
- allow_flagging='never')
112
 
113
- demo.launch()
 
 
 
 
 
 
 
114
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
2
  import tensorflow as tf
3
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
4
+ import numpy as np
5
+ from PIL import Image
6
+ from matplotlib import gridspec
7
+ import matplotlib.pyplot as plt
8
+
9
+ # ๋ชจ๋ธ ๋ฐ ํŠน์„ฑ ์ถ”์ถœ๊ธฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค
10
+ feature_extractor = SegformerFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
11
+ model = TFSegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
12
 
 
 
 
 
 
 
13
 
14
+ # ADE20K ํŒ”๋ ˆํŠธ ๋ฐ ๋ผ๋ฒจ ๋ชฉ๋ก์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค
15
  def ade_palette():
 
16
  return [
17
  [34, 116, 28],
18
  [84, 57, 0],
 
34
  [206, 114, 61],
35
  ]
36
 
37
+
38
  labels_list = []
39
 
40
+ with open('labels.txt', 'r') as fp:
41
  for line in fp:
42
  labels_list.append(line[:-1])
43
 
44
  colormap = np.asarray(ade_palette())
45
 
46
+
47
+ # ๋ผ๋ฒจ์„ ์ปฌ๋Ÿฌ ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค
48
  def label_to_color_image(label):
49
  if label.ndim != 2:
50
  raise ValueError("Expect 2-D input label")
 
51
  if np.max(label) >= len(colormap):
52
  raise ValueError("label value too large.")
53
  return colormap[label]
54
 
55
+
56
+ # ์˜ˆ์ธก ์ด๋ฏธ์ง€์™€ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜์„ ์‹œ๊ฐํ™”ํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค
57
  def draw_plot(pred_img, seg):
58
  fig = plt.figure(figsize=(20, 15))
 
59
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
60
 
61
  plt.subplot(grid_spec[0])
 
64
  LABEL_NAMES = np.asarray(labels_list)
65
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
66
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
 
67
  unique_labels = np.unique(seg.numpy().astype("uint8"))
68
  ax = plt.subplot(grid_spec[1])
69
  plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
 
74
  return fig
75
 
76
 
77
+ # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค
78
+ def preprocess_image(input_img):
79
  input_img = Image.fromarray(input_img)
80
 
81
+ # ์ด๋ฏธ์ง€๋ฅผ ๋ชจ๋ธ์˜ ์ž…๋ ฅ ํ˜•์‹์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค
82
  inputs = feature_extractor(images=input_img, return_tensors="tf")
83
  outputs = model(**inputs)
84
  logits = outputs.logits
 
85
  logits = tf.transpose(logits, [0, 2, 3, 1])
86
 
87
+ # ํฌ๊ธฐ ์กฐ์ •
88
+ logits = tf.image.resize(logits, [input_img.size[1], input_img.size[0]])
 
 
89
 
90
  seg = tf.math.argmax(logits, axis=-1)[0]
91
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
92
 
 
 
 
93
  for label, color in enumerate(colormap):
94
  color_seg[seg.numpy() == label, :] = color
95
 
 
96
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
97
  pred_img = pred_img.astype(np.uint8)
98
 
99
  fig = draw_plot(pred_img, seg)
100
  return fig
101
 
 
 
 
 
 
102
 
103
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค
104
+ demo = gr.Interface(
105
+ fn=preprocess_image,
106
+ inputs=gr.Image(type='pil'),
107
+ outputs=['plot'],
108
+ examples=["person-1.jpg", "person-2.jpg", "person-3.jpg", "person-4.jpg", "person-5.jpg"],
109
+ allow_flagging='never'
110
+ )
111
 
112
+ # ์•ฑ์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค
113
+ demo.launch()