Hantr committed
Commit f3c8437 · 1 Parent(s): f47dbf7
Files changed (1)
  1. app.py +21 -14
app.py CHANGED
@@ -6,14 +6,16 @@ import numpy as np
 from PIL import Image
 import tensorflow as tf
 from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
+from transformers import BigBirdForImageCaptioning
 
 feature_extractor = SegformerFeatureExtractor.from_pretrained(
     "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
 )
-model = TFSegformerForSemanticSegmentation.from_pretrained(
+seg_model = TFSegformerForSemanticSegmentation.from_pretrained(
     "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
 )
 
+caption_model = BigBirdForImageCaptioning.from_pretrained("bigbird/image-captioning-base")
 
 def ade_palette():
     """ADE20K palette that maps each class to RGB values."""
@@ -82,7 +84,7 @@ def sepia(input_img):
     input_img = Image.fromarray(input_img)
 
     inputs = feature_extractor(images=input_img, return_tensors="tf")
-    outputs = model(**inputs)
+    outputs = seg_model(**inputs)
     logits = outputs.logits
 
     logits = tf.transpose(logits, [0, 2, 3, 1])
@@ -105,11 +107,12 @@ def sepia(input_img):
     return fig
 
 
-def with_labels(input_img):
+def segment_and_caption(input_img):
     input_img = Image.fromarray(input_img)
 
+    # Run segmentation
     inputs = feature_extractor(images=input_img, return_tensors="tf")
-    outputs = model(**inputs)
+    outputs = seg_model(**inputs)
     logits = outputs.logits
 
     logits = tf.transpose(logits, [0, 2, 3, 1])
@@ -118,21 +121,25 @@ def with_labels(input_img):
     )
     seg = tf.math.argmax(logits, axis=-1)[0]
 
-    color_seg = np.zeros(
-        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
-    )
-    for label, color in enumerate(colormap):
-        color_seg[seg.numpy() == label, :] = color
+    # Convert the segmentation result to text
+    seg_text = ""
+    for label, label_name in enumerate(labels_list):
+        count = np.sum(seg.numpy() == label)
+        seg_text += f"{label_name}: {count} pixels\n"
+
+    # Generate an image caption
+    caption = caption_model.generate(input_img, max_length=20, num_return_sequences=1, return_dict_in_generate=True)
+    caption_text = caption[0]['text']
+
+    # Return the segmentation result and the caption
+    return input_img, seg_text, caption_text
 
-    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
-    pred_img = pred_img.astype(np.uint8)
 
-    return input_img, pred_img, labels_list
 
 
-demo = gr.Interface(fn=with_labels,
+demo = gr.Interface(fn=segment_and_caption,
                     inputs=gr.Image(shape=(1024, 1024)),
-                    outputs=["image", "image", "text"],
+                    outputs=["image", "text", "text"],
                     examples=["city-1.jpg", "city-2.jpg", "city-3.jpg", "city-4.jpg", "city-5.jpg"],
                     allow_flagging='never')
 
145