Spaces:
Paused
Paused
#!/usr/bin/env python | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
import argparse | |
import os | |
from itertools import chain | |
import cv2 | |
import tqdm | |
from detectron2.config import get_cfg | |
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_train_loader | |
from detectron2.data import detection_utils as utils | |
from detectron2.data.build import filter_images_with_few_keypoints | |
from detectron2.utils.logger import setup_logger | |
from detectron2.utils.visualizer import Visualizer | |
def setup(args):
    """Build the frozen config used by this script.

    Starts from the defaults returned by ``get_cfg()``, merges the file named
    by ``args.config_file`` when one is given, then applies the command-line
    overrides in ``args.opts``.

    Args:
        args: parsed command-line namespace; ``config_file`` and ``opts`` are
            read from it.

    Returns:
        The config object, frozen after all merges.
    """
    config = get_cfg()
    config_path = args.config_file
    if config_path:
        config.merge_from_file(config_path)
    # Command-line overrides are merged last so they win over the file.
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def parse_args(in_args=None):
    """Parse the command-line options of the visualization script.

    Args:
        in_args: list of argument strings to parse; when ``None`` it is
            forwarded as-is to ``ArgumentParser.parse_args``, which then
            falls back to the process arguments.

    Returns:
        The ``argparse.Namespace`` with ``source``, ``config_file``,
        ``output_dir``, ``show`` and ``opts``.
    """
    cli = argparse.ArgumentParser(description="Visualize ground-truth data")

    # What to draw: raw dataset annotations, or data-loader output.
    cli.add_argument(
        "--source",
        required=True,
        choices=["annotation", "dataloader"],
        help="visualize the annotations or the data loader (with pre-processing)",
    )
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
    )
    cli.add_argument(
        "--output-dir",
        default="./",
        help="path to output directory",
    )
    cli.add_argument(
        "--show",
        action="store_true",
        help="show output in a window",
    )
    # Everything left after the options above is collected as config overrides.
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    parsed = cli.parse_args(in_args)
    return parsed
if __name__ == "__main__":
    # Script entry point: draw ground-truth annotations, either straight from
    # the dataset dicts or from the batches produced by the training loader.
    args = parse_args()
    logger = setup_logger()
    arguments_repr = str(args)
    logger.info("Arguments: " + arguments_repr)
    cfg = setup(args)

    dirname = args.output_dir
    os.makedirs(dirname, exist_ok=True)
    # Metadata of the first training dataset is used for every image drawn.
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    def output(vis, fname):
        """Display ``vis`` in a window (``--show``) or save it as ``fname``."""
        if not args.show:
            filepath = os.path.join(dirname, fname)
            print("Saving to {} ...".format(filepath))
            vis.save(filepath)
            return
        print(fname)
        # The last axis is reversed before display -- presumably RGB -> BGR
        # for OpenCV; confirm against the Visualizer's output format.
        cv2.imshow("window", vis.get_image()[:, :, ::-1])
        # Block until a key is pressed before moving to the next image.
        cv2.waitKey()

    # Images shown on screen are drawn at twice the size of saved ones.
    if args.show:
        scale = 2.0
    else:
        scale = 1.0

    if args.source == "dataloader":
        # NOTE(review): the loop runs for as long as the loader yields
        # batches; confirm whether the training loader ever stops.
        for batch in build_detection_train_loader(cfg):
            for sample in batch:
                # Pytorch tensor is in (C, H, W) format
                image_tensor = sample["image"]
                img = image_tensor.permute(1, 2, 0).cpu().detach().numpy()
                img = utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)

                visualizer = Visualizer(img, metadata=metadata, scale=scale)
                fields = sample["instances"].get_fields()

                # Translate every ground-truth class index into its name.
                labels = []
                for class_idx in fields["gt_classes"]:
                    labels.append(metadata.thing_classes[class_idx])

                # Fields missing from the instances are passed as None.
                vis = visualizer.overlay_instances(
                    labels=labels,
                    boxes=fields.get("gt_boxes", None),
                    masks=fields.get("gt_masks", None),
                    keypoints=fields.get("gt_keypoints", None),
                )
                output(vis, str(sample["image_id"]) + ".jpg")
    else:
        # Concatenate the dataset dicts of all training datasets.
        per_dataset = [DatasetCatalog.get(name) for name in cfg.DATASETS.TRAIN]
        dicts = list(chain(*per_dataset))
        if cfg.MODEL.KEYPOINT_ON:
            dicts = filter_images_with_few_keypoints(dicts, 1)

        for record in tqdm.tqdm(dicts):
            image_path = record["file_name"]
            img = utils.read_image(image_path, "RGB")
            visualizer = Visualizer(img, metadata=metadata, scale=scale)
            vis = visualizer.draw_dataset_dict(record)
            # Saved under the image's base name inside the output directory.
            output(vis, os.path.basename(record["file_name"]))