add
- app.py +331 -0
- demo/examples/ade.jpg +0 -0
- demo/examples/coco.jpg +0 -0
- demo/examples/ego4d.jpg +0 -0
- packages.txt +4 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,331 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License.
+# To view a copy of this license, visit
+# https://github.com/NVlabs/ODISE/blob/main/LICENSE
+#
+# Written by Jiarui Xu
+# ------------------------------------------------------------------------------
+
+import os
+token = os.environ["GITHUB_TOKEN"]
+os.system(f"pip install git+https://xvjiarui:{token}@github.com/xvjiarui/ODISE_NV.git")
+
+import itertools
+import json
+from contextlib import ExitStack
+import gradio as gr
+import torch
+from mask2former.data.datasets.register_ade20k_panoptic import ADE20K_150_CATEGORIES
+from PIL import Image
+from torch.cuda.amp import autocast
+
+from detectron2.config import instantiate
+from detectron2.data import MetadataCatalog
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+from detectron2.evaluation import inference_context
+from detectron2.utils.env import seed_all_rng
+from detectron2.utils.logger import setup_logger
+from detectron2.utils.visualizer import ColorMode, Visualizer, random_color
+
+from odise import model_zoo
+from odise.checkpoint import ODISECheckpointer
+from odise.config import instantiate_odise
+from odise.data import get_openseg_labels
+from odise.modeling.wrapper import OpenPanopticInference
+from odise.utils.file_io import ODISEHandler, PathManager
+from odise.model_zoo.model_zoo import _ModelZooUrls
+
+for k in ODISEHandler.URLS:
+    ODISEHandler.URLS[k] = ODISEHandler.URLS[k].replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+PathManager.register_handler(ODISEHandler())
+_ModelZooUrls.PREFIX = _ModelZooUrls.PREFIX.replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+
+setup_logger()
+logger = setup_logger(name="odise")
+
+COCO_THING_CLASSES = [
+    label
+    for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+    if COCO_CATEGORIES[idx]["isthing"] == 1
+]
+COCO_THING_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 1]
+COCO_STUFF_CLASSES = [
+    label
+    for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+    if COCO_CATEGORIES[idx]["isthing"] == 0
+]
+COCO_STUFF_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 0]
+
+ADE_THING_CLASSES = [
+    label
+    for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+    if ADE20K_150_CATEGORIES[idx]["isthing"] == 1
+]
+ADE_THING_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 1]
+ADE_STUFF_CLASSES = [
+    label
+    for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+    if ADE20K_150_CATEGORIES[idx]["isthing"] == 0
+]
+ADE_STUFF_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 0]
+
+LVIS_CLASSES = get_openseg_labels("lvis_1203", True)
+# use beautiful coco colors
+LVIS_COLORS = list(
+    itertools.islice(itertools.cycle([c["color"] for c in COCO_CATEGORIES]), len(LVIS_CLASSES))
+)
+
+
+class VisualizationDemo(object):
+    def __init__(self, model, metadata, aug, instance_mode=ColorMode.IMAGE):
+        """
+        Args:
+            model (nn.Module):
+            metadata (MetadataCatalog): image metadata.
+            instance_mode (ColorMode):
+            parallel (bool): whether to run the model in different processes from visualization.
+                Useful since the visualization logic can be slow.
+        """
+        self.model = model
+        self.metadata = metadata
+        self.aug = aug
+        self.cpu_device = torch.device("cpu")
+        self.instance_mode = instance_mode
+
+    def predict(self, original_image):
+        """
+        Args:
+            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+        Returns:
+            predictions (dict):
+                the output of the model for one image only.
+                See :doc:`/tutorials/models` for details about the format.
+        """
+        height, width = original_image.shape[:2]
+        aug_input = T.AugInput(original_image, sem_seg=None)
+        self.aug(aug_input)
+        image = aug_input.image
+        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+        inputs = {"image": image, "height": height, "width": width}
+        logger.info("forwarding")
+        with autocast():
+            predictions = self.model([inputs])[0]
+        logger.info("done")
+        return predictions
+
+    def run_on_image(self, image):
+        """
+        Args:
+            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+                This is the format used by OpenCV.
+        Returns:
+            predictions (dict): the output of the model.
+            vis_output (VisImage): the visualized image output.
+        """
+        vis_output = None
+        predictions = self.predict(image)
+        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
+        if "panoptic_seg" in predictions:
+            panoptic_seg, segments_info = predictions["panoptic_seg"]
+            vis_output = visualizer.draw_panoptic_seg(
+                panoptic_seg.to(self.cpu_device), segments_info
+            )
+        else:
+            if "sem_seg" in predictions:
+                vis_output = visualizer.draw_sem_seg(
+                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
+                )
+            if "instances" in predictions:
+                instances = predictions["instances"].to(self.cpu_device)
+                vis_output = visualizer.draw_instance_predictions(predictions=instances)
+
+        return predictions, vis_output
+
+
+cfg = model_zoo.get_config("Panoptic/odise_label_coco_50e.py", trained=True)
+
+cfg.model.overlap_threshold = 0
+cfg.train.device = "cuda" if torch.cuda.is_available() else "cpu"
+seed_all_rng(42)
+
+dataset_cfg = cfg.dataloader.test
+wrapper_cfg = cfg.dataloader.wrapper
+
+aug = instantiate(dataset_cfg.mapper).augmentations
+
+model = instantiate_odise(cfg.model)
+model.to(torch.float16)
+model.to(cfg.train.device)
+ODISECheckpointer(model).load(cfg.train.init_checkpoint)
+
+
+title = "ODISE"
+description = """
+<p style='text-align: center'> <a href='https://jerryxu.net/ODISE' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2303.04803' target='_blank'>Paper</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Code</a> | <a href='https://youtu.be/Su7p5KYmcII' target='_blank'>Video</a></p>
+
+Gradio demo for ODISE: Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models. \n
+You may click one of the examples or upload your own image. \n
+
+ODISE can perform open-vocabulary segmentation; you may input extra classes (separated by commas).
+The expected format is 'a1,a2;b1,b2', where a1 and a2 are synonym vocabularies for the first class.
+The first word will be displayed as the class name.
+""" # noqa
+
+article = """
+<p style='text-align: center'><a href='https://arxiv.org/abs/2303.04803' target='_blank'>Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Github Repo</a></p>
+""" # noqa
183 |
+
|
184 |
+
examples = [
|
185 |
+
[
|
186 |
+
"demo/examples/coco.jpg",
|
187 |
+
"black pickup truck, pickup truck; blue sky, sky",
|
188 |
+
["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
|
189 |
+
],
|
190 |
+
[
|
191 |
+
"demo/examples/ade.jpg",
|
192 |
+
"luggage, suitcase, baggage;handbag",
|
193 |
+
["ADE (150 categories)"],
|
194 |
+
],
|
195 |
+
[
|
196 |
+
"demo/examples/ego4d.jpg",
|
197 |
+
"faucet, tap; kitchen paper, paper towels",
|
198 |
+
["COCO (133 categories)"],
|
199 |
+
],
|
200 |
+
]
|
201 |
+
|
202 |
+
|
203 |
+
def build_demo_classes_and_metadata(vocab, label_list):
|
204 |
+
extra_classes = []
|
205 |
+
|
206 |
+
if vocab:
|
207 |
+
for words in vocab.split(";"):
|
208 |
+
extra_classes.append([word.strip() for word in words.split(",")])
|
209 |
+
extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))]
|
210 |
+
|
211 |
+
demo_thing_classes = extra_classes
|
212 |
+
demo_stuff_classes = []
|
213 |
+
demo_thing_colors = extra_colors
|
214 |
+
demo_stuff_colors = []
|
215 |
+
|
216 |
+
if any("COCO" in label for label in label_list):
|
217 |
+
demo_thing_classes += COCO_THING_CLASSES
|
218 |
+
demo_stuff_classes += COCO_STUFF_CLASSES
|
219 |
+
demo_thing_colors += COCO_THING_COLORS
|
220 |
+
demo_stuff_colors += COCO_STUFF_COLORS
|
221 |
+
if any("ADE" in label for label in label_list):
|
222 |
+
demo_thing_classes += ADE_THING_CLASSES
|
223 |
+
demo_stuff_classes += ADE_STUFF_CLASSES
|
224 |
+
demo_thing_colors += ADE_THING_COLORS
|
225 |
+
demo_stuff_colors += ADE_STUFF_COLORS
|
226 |
+
if any("LVIS" in label for label in label_list):
|
227 |
+
demo_thing_classes += LVIS_CLASSES
|
228 |
+
demo_thing_colors += LVIS_COLORS
|
229 |
+
|
230 |
+
MetadataCatalog.pop("odise_demo_metadata", None)
|
231 |
+
demo_metadata = MetadataCatalog.get("odise_demo_metadata")
|
232 |
+
demo_metadata.thing_classes = [c[0] for c in demo_thing_classes]
|
233 |
+
demo_metadata.stuff_classes = [
|
234 |
+
*demo_metadata.thing_classes,
|
235 |
+
*[c[0] for c in demo_stuff_classes],
|
236 |
+
]
|
237 |
+
demo_metadata.thing_colors = demo_thing_colors
|
238 |
+
demo_metadata.stuff_colors = demo_thing_colors + demo_stuff_colors
|
239 |
+
demo_metadata.stuff_dataset_id_to_contiguous_id = {
|
240 |
+
idx: idx for idx in range(len(demo_metadata.stuff_classes))
|
241 |
+
}
|
242 |
+
demo_metadata.thing_dataset_id_to_contiguous_id = {
|
243 |
+
idx: idx for idx in range(len(demo_metadata.thing_classes))
|
244 |
+
}
|
245 |
+
|
246 |
+
demo_classes = demo_thing_classes + demo_stuff_classes
|
247 |
+
|
248 |
+
return demo_classes, demo_metadata
|
249 |
+
|
250 |
+
|
251 |
+
def inference(image_path, vocab, label_list):
|
252 |
+
|
253 |
+
logger.info("building class names")
|
254 |
+
demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list)
|
255 |
+
with ExitStack() as stack:
|
256 |
+
inference_model = OpenPanopticInference(
|
257 |
+
model=model,
|
258 |
+
labels=demo_classes,
|
259 |
+
metadata=demo_metadata,
|
260 |
+
semantic_on=False,
|
261 |
+
instance_on=False,
|
262 |
+
panoptic_on=True,
|
263 |
+
)
|
264 |
+
stack.enter_context(inference_context(inference_model))
|
265 |
+
stack.enter_context(torch.no_grad())
|
266 |
+
|
267 |
+
demo = VisualizationDemo(inference_model, demo_metadata, aug)
|
268 |
+
img = utils.read_image(image_path, format="RGB")
|
269 |
+
_, visualized_output = demo.run_on_image(img)
|
270 |
+
return Image.fromarray(visualized_output.get_image())
|
271 |
+
|
272 |
+
|
273 |
+
with gr.Blocks(title=title) as demo:
|
274 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
|
275 |
+
gr.Markdown(description)
|
276 |
+
input_components = []
|
277 |
+
output_components = []
|
278 |
+
|
279 |
+
with gr.Row():
|
280 |
+
output_image_gr = gr.outputs.Image(label="Panoptic Segmentation", type="pil")
|
281 |
+
output_components.append(output_image_gr)
|
282 |
+
|
283 |
+
with gr.Row().style(equal_height=True, mobile_collapse=True):
|
284 |
+
with gr.Column(scale=3, variant="panel") as input_component_column:
|
285 |
+
input_image_gr = gr.inputs.Image(type="filepath")
|
286 |
+
extra_vocab_gr = gr.inputs.Textbox(default="", label="Extra Vocabulary")
|
287 |
+
category_list_gr = gr.inputs.CheckboxGroup(
|
288 |
+
choices=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
|
289 |
+
default=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
|
290 |
+
label="Category to use",
|
291 |
+
)
|
292 |
+
input_components.extend([input_image_gr, extra_vocab_gr, category_list_gr])
|
293 |
+
|
294 |
+
with gr.Column(scale=2):
|
295 |
+
examples_handler = gr.Examples(
|
296 |
+
examples=examples,
|
297 |
+
inputs=[c for c in input_components if not isinstance(c, gr.State)],
|
298 |
+
outputs=[c for c in output_components if not isinstance(c, gr.State)],
|
299 |
+
fn=inference,
|
300 |
+
cache_examples=torch.cuda.is_available(),
|
301 |
+
examples_per_page=5,
|
302 |
+
)
|
303 |
+
with gr.Row():
|
304 |
+
clear_btn = gr.Button("Clear")
|
305 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
306 |
+
|
307 |
+
gr.Markdown(article)
|
308 |
+
|
309 |
+
submit_btn.click(
|
310 |
+
inference,
|
311 |
+
input_components,
|
312 |
+
output_components,
|
313 |
+
api_name="predict",
|
314 |
+
scroll_to_output=True,
|
315 |
+
)
|
316 |
+
|
317 |
+
clear_btn.click(
|
318 |
+
None,
|
319 |
+
[],
|
320 |
+
(input_components + output_components + [input_component_column]),
|
321 |
+
_js=f"""() => {json.dumps(
|
322 |
+
[component.cleared_value if hasattr(component, "cleared_value") else None
|
323 |
+
for component in input_components + output_components] + (
|
324 |
+
[gr.Column.update(visible=True)]
|
325 |
+
)
|
326 |
+
+ ([gr.Column.update(visible=False)])
|
327 |
+
)}
|
328 |
+
""",
|
329 |
+
)
|
330 |
+
|
331 |
+
demo.launch()
|
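For reference, the "Extra Vocabulary" string accepted by app.py is parsed in `build_demo_classes_and_metadata` above: semicolons separate classes, commas separate synonyms within a class, and the first synonym is used as the display name. Below is a minimal standalone sketch of that parsing; the helper name `parse_extra_vocab` is illustrative and not part of the committed file.

```python
# Illustrative sketch only: mirrors the vocab-splitting logic in build_demo_classes_and_metadata.
def parse_extra_vocab(vocab):
    """Split 'a1,a2;b1,b2' into [['a1', 'a2'], ['b1', 'b2']]."""
    extra_classes = []
    if vocab:
        for words in vocab.split(";"):
            extra_classes.append([word.strip() for word in words.split(",")])
    return extra_classes


print(parse_extra_vocab("black pickup truck, pickup truck; blue sky, sky"))
# [['black pickup truck', 'pickup truck'], ['blue sky', 'sky']]
```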
demo/examples/ade.jpg
ADDED
demo/examples/coco.jpg
ADDED
demo/examples/ego4d.jpg
ADDED
packages.txt
ADDED
@@ -0,0 +1,4 @@
+libtinfo5
+libsm6
+libxext6
+python3-opencv
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+--extra-index-url https://download.pytorch.org/whl/cu116
+torch==1.13.1+cu116
+torchvision==0.14.1+cu116
+xformers==0.0.16
+numpy<=1.21.5