mokazht committed
Commit 13d8909 · 1 Parent(s): f58cc9e
app.py ADDED
@@ -0,0 +1,331 @@
+ # ------------------------------------------------------------------------------
+ # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License.
+ # To view a copy of this license, visit
+ # https://github.com/NVlabs/ODISE/blob/main/LICENSE
+ #
+ # Written by Jiarui Xu
+ # ------------------------------------------------------------------------------
+
+ import os
+ token = os.environ["GITHUB_TOKEN"]
+ os.system(f"pip install git+https://xvjiarui:{token}@github.com/xvjiarui/ODISE_NV.git")
+
+ import itertools
+ import json
+ from contextlib import ExitStack
+ import gradio as gr
+ import torch
+ from mask2former.data.datasets.register_ade20k_panoptic import ADE20K_150_CATEGORIES
+ from PIL import Image
+ from torch.cuda.amp import autocast
+
+ from detectron2.config import instantiate
+ from detectron2.data import MetadataCatalog
+ from detectron2.data import detection_utils as utils
+ from detectron2.data import transforms as T
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+ from detectron2.evaluation import inference_context
+ from detectron2.utils.env import seed_all_rng
+ from detectron2.utils.logger import setup_logger
+ from detectron2.utils.visualizer import ColorMode, Visualizer, random_color
+
+ from odise import model_zoo
+ from odise.checkpoint import ODISECheckpointer
+ from odise.config import instantiate_odise
+ from odise.data import get_openseg_labels
+ from odise.modeling.wrapper import OpenPanopticInference
+ from odise.utils.file_io import ODISEHandler, PathManager
+ from odise.model_zoo.model_zoo import _ModelZooUrls
+
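+ # Redirect the ODISE checkpoint URLs from the GitHub release assets to a Hugging Face
+ # mirror so the pretrained weights can be downloaded inside the Space.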
+ for k in ODISEHandler.URLS:
+     ODISEHandler.URLS[k] = ODISEHandler.URLS[k].replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+ PathManager.register_handler(ODISEHandler())
+ _ModelZooUrls.PREFIX = _ModelZooUrls.PREFIX.replace("https://github.com/NVlabs/ODISE/releases/download/v1.0.0/", "https://huggingface.co/xvjiarui/download_cache/resolve/main/torch/odise/")
+
+ setup_logger()
+ logger = setup_logger(name="odise")
+
+ COCO_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 1
+ ]
+ COCO_THING_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 1]
+ COCO_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("coco_panoptic", True))
+     if COCO_CATEGORIES[idx]["isthing"] == 0
+ ]
+ COCO_STUFF_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 0]
+
+ ADE_THING_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 1
+ ]
+ ADE_THING_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 1]
+ ADE_STUFF_CLASSES = [
+     label
+     for idx, label in enumerate(get_openseg_labels("ade20k_150", True))
+     if ADE20K_150_CATEGORIES[idx]["isthing"] == 0
+ ]
+ ADE_STUFF_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 0]
+
+ LVIS_CLASSES = get_openseg_labels("lvis_1203", True)
+ # use beautiful coco colors
+ LVIS_COLORS = list(
+     itertools.islice(itertools.cycle([c["color"] for c in COCO_CATEGORIES]), len(LVIS_CLASSES))
+ )
+
+
+ class VisualizationDemo(object):
+     def __init__(self, model, metadata, aug, instance_mode=ColorMode.IMAGE):
+         """
+         Args:
+             model (nn.Module): the model to run inference with.
+             metadata (MetadataCatalog): image metadata.
+             aug: test-time augmentation(s) applied to the input image before inference.
+             instance_mode (ColorMode): color mode used by the visualizer.
+         """
+         self.model = model
+         self.metadata = metadata
+         self.aug = aug
+         self.cpu_device = torch.device("cpu")
+         self.instance_mode = instance_mode
+
+     def predict(self, original_image):
+         """
+         Args:
+             original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+         Returns:
+             predictions (dict):
+                 the output of the model for one image only.
+                 See :doc:`/tutorials/models` for details about the format.
+         """
+         height, width = original_image.shape[:2]
+         aug_input = T.AugInput(original_image, sem_seg=None)
+         self.aug(aug_input)
+         image = aug_input.image
+         image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+         inputs = {"image": image, "height": height, "width": width}
+         logger.info("forwarding")
+         with autocast():
+             predictions = self.model([inputs])[0]
+         logger.info("done")
+         return predictions
+
+     def run_on_image(self, image):
+         """
+         Args:
+             image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+                 This is the format used by OpenCV.
+         Returns:
+             predictions (dict): the output of the model.
+             vis_output (VisImage): the visualized image output.
+         """
+         vis_output = None
+         predictions = self.predict(image)
+         visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
+         if "panoptic_seg" in predictions:
+             panoptic_seg, segments_info = predictions["panoptic_seg"]
+             vis_output = visualizer.draw_panoptic_seg(
+                 panoptic_seg.to(self.cpu_device), segments_info
+             )
+         else:
+             if "sem_seg" in predictions:
+                 vis_output = visualizer.draw_sem_seg(
+                     predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
+                 )
+             if "instances" in predictions:
+                 instances = predictions["instances"].to(self.cpu_device)
+                 vis_output = visualizer.draw_instance_predictions(predictions=instances)
+
+         return predictions, vis_output
+
+
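+ # Load the ODISE config and pretrained weights from the model zoo; cast the model to
+ # fp16 and run on CUDA when a GPU is available, otherwise fall back to CPU.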
+ cfg = model_zoo.get_config("Panoptic/odise_label_coco_50e.py", trained=True)
+
+ cfg.model.overlap_threshold = 0
+ cfg.train.device = "cuda" if torch.cuda.is_available() else "cpu"
+ seed_all_rng(42)
+
+ dataset_cfg = cfg.dataloader.test
+ wrapper_cfg = cfg.dataloader.wrapper
+
+ aug = instantiate(dataset_cfg.mapper).augmentations
+
+ model = instantiate_odise(cfg.model)
+ model.to(torch.float16)
+ model.to(cfg.train.device)
+ ODISECheckpointer(model).load(cfg.train.init_checkpoint)
+
+
+ title = "ODISE"
+ description = """
+ <p style='text-align: center'> <a href='https://jerryxu.net/ODISE' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2303.04803' target='_blank'>Paper</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>Code</a> | <a href='https://youtu.be/Su7p5KYmcII' target='_blank'>Video</a></p>
+
+ Gradio demo for ODISE: Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models. \n
+ You may click on one of the examples or upload your own image. \n
+
+ ODISE performs open-vocabulary segmentation, so you may add extra classes (separated by commas).
+ The expected format is 'a1,a2;b1,b2', where a1 and a2 are synonyms for the first class.
+ The first word in each group is displayed as the class name.
+ """ # noqa
+
+ article = """
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2303.04803' target='_blank'>Open-Vocabulary Panoptic Segmentation with Text-to-Image Diffusion Models</a> | <a href='https://github.com/NVlabs/ODISE' target='_blank'>GitHub Repo</a></p>
+ """ # noqa
+
+ examples = [
+     [
+         "demo/examples/coco.jpg",
+         "black pickup truck, pickup truck; blue sky, sky",
+         ["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+     ],
+     [
+         "demo/examples/ade.jpg",
+         "luggage, suitcase, baggage;handbag",
+         ["ADE (150 categories)"],
+     ],
+     [
+         "demo/examples/ego4d.jpg",
+         "faucet, tap; kitchen paper, paper towels",
+         ["COCO (133 categories)"],
+     ],
+ ]
+
+
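+ # Parse the extra-vocabulary string and merge it with the selected base label sets into
+ # detectron2 metadata. For example, "black pickup truck, pickup truck; blue sky, sky"
+ # becomes [["black pickup truck", "pickup truck"], ["blue sky", "sky"]]; the first
+ # synonym in each group is used as the display name.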
+ def build_demo_classes_and_metadata(vocab, label_list):
+     extra_classes = []
+
+     if vocab:
+         for words in vocab.split(";"):
+             extra_classes.append([word.strip() for word in words.split(",")])
+     extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))]
+
+     demo_thing_classes = extra_classes
+     demo_stuff_classes = []
+     demo_thing_colors = extra_colors
+     demo_stuff_colors = []
+
+     if any("COCO" in label for label in label_list):
+         demo_thing_classes += COCO_THING_CLASSES
+         demo_stuff_classes += COCO_STUFF_CLASSES
+         demo_thing_colors += COCO_THING_COLORS
+         demo_stuff_colors += COCO_STUFF_COLORS
+     if any("ADE" in label for label in label_list):
+         demo_thing_classes += ADE_THING_CLASSES
+         demo_stuff_classes += ADE_STUFF_CLASSES
+         demo_thing_colors += ADE_THING_COLORS
+         demo_stuff_colors += ADE_STUFF_COLORS
+     if any("LVIS" in label for label in label_list):
+         demo_thing_classes += LVIS_CLASSES
+         demo_thing_colors += LVIS_COLORS
+
+     MetadataCatalog.pop("odise_demo_metadata", None)
+     demo_metadata = MetadataCatalog.get("odise_demo_metadata")
+     demo_metadata.thing_classes = [c[0] for c in demo_thing_classes]
+     demo_metadata.stuff_classes = [
+         *demo_metadata.thing_classes,
+         *[c[0] for c in demo_stuff_classes],
+     ]
+     demo_metadata.thing_colors = demo_thing_colors
+     demo_metadata.stuff_colors = demo_thing_colors + demo_stuff_colors
+     demo_metadata.stuff_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.stuff_classes))
+     }
+     demo_metadata.thing_dataset_id_to_contiguous_id = {
+         idx: idx for idx in range(len(demo_metadata.thing_classes))
+     }
+
+     demo_classes = demo_thing_classes + demo_stuff_classes
+
+     return demo_classes, demo_metadata
+
+
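+ # Wrap the base model in OpenPanopticInference so it predicts over the chosen label set,
+ # run panoptic segmentation on the uploaded image, and return the visualization.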
+ def inference(image_path, vocab, label_list):
+
+     logger.info("building class names")
+     demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list)
+     with ExitStack() as stack:
+         inference_model = OpenPanopticInference(
+             model=model,
+             labels=demo_classes,
+             metadata=demo_metadata,
+             semantic_on=False,
+             instance_on=False,
+             panoptic_on=True,
+         )
+         stack.enter_context(inference_context(inference_model))
+         stack.enter_context(torch.no_grad())
+
+         demo = VisualizationDemo(inference_model, demo_metadata, aug)
+         img = utils.read_image(image_path, format="RGB")
+         _, visualized_output = demo.run_on_image(img)
+         return Image.fromarray(visualized_output.get_image())
+
+
+ with gr.Blocks(title=title) as demo:
+     gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+     gr.Markdown(description)
+     input_components = []
+     output_components = []
+
+     with gr.Row():
+         output_image_gr = gr.outputs.Image(label="Panoptic Segmentation", type="pil")
+         output_components.append(output_image_gr)
+
+     with gr.Row().style(equal_height=True, mobile_collapse=True):
+         with gr.Column(scale=3, variant="panel") as input_component_column:
+             input_image_gr = gr.inputs.Image(type="filepath")
+             extra_vocab_gr = gr.inputs.Textbox(default="", label="Extra Vocabulary")
+             category_list_gr = gr.inputs.CheckboxGroup(
+                 choices=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 default=["COCO (133 categories)", "ADE (150 categories)", "LVIS (1203 categories)"],
+                 label="Category to use",
+             )
+             input_components.extend([input_image_gr, extra_vocab_gr, category_list_gr])
+
+         with gr.Column(scale=2):
+             examples_handler = gr.Examples(
+                 examples=examples,
+                 inputs=[c for c in input_components if not isinstance(c, gr.State)],
+                 outputs=[c for c in output_components if not isinstance(c, gr.State)],
+                 fn=inference,
+                 cache_examples=torch.cuda.is_available(),
+                 examples_per_page=5,
+             )
+             with gr.Row():
+                 clear_btn = gr.Button("Clear")
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+     gr.Markdown(article)
+
+     submit_btn.click(
+         inference,
+         input_components,
+         output_components,
+         api_name="predict",
+         scroll_to_output=True,
+     )
+
+     clear_btn.click(
+         None,
+         [],
+         (input_components + output_components + [input_component_column]),
+         _js=f"""() => {json.dumps(
+             [component.cleared_value if hasattr(component, "cleared_value") else None
+                 for component in input_components + output_components] + (
+                 [gr.Column.update(visible=True)]
+             )
+             + ([gr.Column.update(visible=False)])
+         )}
+         """,
+     )
+
+ demo.launch()
demo/examples/ade.jpg ADDED
demo/examples/coco.jpg ADDED
demo/examples/ego4d.jpg ADDED
packages.txt ADDED
@@ -0,0 +1,4 @@
+ libtinfo5
+ libsm6
+ libxext6
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch==1.13.1+cu116
+ torchvision==0.14.1+cu116
+ xformers==0.0.16
+ numpy<=1.21.5