# Copyright (c) Meta Platforms, Inc. and affiliates
import argparse
import logging
import os
import sys
from collections import OrderedDict

import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import transforms as T
from detectron2.engine import default_argument_parser, default_setup, launch

logger = logging.getLogger("detectron2")

sys.dont_write_bytecode = True
sys.path.append(os.getcwd())
np.set_printoptions(suppress=True)

from cubercnn.config import get_cfg_defaults
from cubercnn.modeling.proposal_generator import RPNWithIgnore
from cubercnn.modeling.roi_heads import ROIHeads3D
from cubercnn.modeling.meta_arch import RCNN3D, build_model
from cubercnn.modeling.backbone import build_dla_from_vision_fpn_backbone
from cubercnn import util, vis
def do_test(args, cfg, model):
    """Run 3D-box inference on every image in ``args.input_folder``.

    For each image this builds a camera intrinsics matrix K (from
    ``--focal-length`` / ``--principal-point``, or per-image heuristics when
    they are absent), loads the precomputed depth map (required) and ground
    map (optional), runs the model, and writes a rendered 3D-box view plus a
    novel top-down view into ``cfg.OUTPUT_DIR``.

    Args:
        args: parsed CLI namespace (input_folder, focal_length,
            principal_point, threshold, display, config_file).
        cfg: frozen detectron2 config.
        model: the Cube R-CNN model, switched to eval mode here.
    """
    list_of_ims = util.list_files(os.path.join(args.input_folder, ''), '*')

    model.eval()

    focal_length = args.focal_length
    principal_point = args.principal_point
    thres = args.threshold

    output_dir = cfg.OUTPUT_DIR
    min_size = cfg.INPUT.MIN_SIZE_TEST
    max_size = cfg.INPUT.MAX_SIZE_TEST
    augmentations = T.AugmentationList([T.ResizeShortestEdge(min_size, max_size, "choice")])

    util.mkdir_if_missing(output_dir)

    category_path = os.path.join(util.file_parts(args.config_file)[0], 'category_meta.json')

    # store locally if needed
    if category_path.startswith(util.CubeRCNNHandler.PREFIX):
        category_path = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, category_path)

    metadata = util.load_json(category_path)
    cats = metadata['thing_classes']

    for path in list_of_ims:

        im_name = util.file_parts(path)[1]
        im = util.imread(path)

        if im is None:
            continue

        image_shape = im.shape[:2]  # h, w
        h, w = image_shape

        # BUG FIX: the original overwrote `focal_length` inside the loop, so
        # the value derived from the FIRST image's height was silently reused
        # for every subsequent image. Keep the CLI value untouched and derive
        # a per-image focal length instead.
        if focal_length == 0:
            focal_length_ndc = 4.0  # generic guess in NDC units
            focal_length_im = focal_length_ndc * h / 2
        else:
            focal_length_im = focal_length

        if len(principal_point) == 0:
            # default the principal point to the image center
            px, py = w / 2, h / 2
        else:
            px, py = principal_point

        K = np.array([
            [focal_length_im, 0.0, px],
            [0.0, focal_length_im, py],
            [0.0, 0.0, 1.0]
        ])

        # first you must run the scripts to get the ground and depth map for the images
        # BUG FIX: the depth map was loaded only inside the ground-map branch,
        # so an image without a ground map raised NameError (or reused the
        # previous image's depth). The depth map is used unconditionally below,
        # so load it unconditionally; only the ground map is optional.
        is_ground = os.path.exists(f'datasets/ground_maps/{im_name}.jpg.npz')
        if is_ground:
            ground_map = np.load(f'datasets/ground_maps/{im_name}.jpg.npz')['mask']
        depth_map = np.load(f'datasets/depth_maps/{im_name}.jpg.npz')['depth']

        aug_input = T.AugInput(im)
        tfms = augmentations(aug_input)
        image = aug_input.image

        if is_ground:
            # apply the same resize transform the image received
            ground_map = tfms.apply_image(ground_map * 1.0)
            ground_map = torch.as_tensor(ground_map)
        else:
            ground_map = None

        depth_map = tfms.apply_image(depth_map)

        batched = [{
            'image': torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))),
            'depth_map': torch.as_tensor(depth_map),
            'ground_map': ground_map,
            'height': image_shape[0], 'width': image_shape[1], 'K': K
        }]

        dets = model(batched)[0]['instances']
        n_det = len(dets)

        meshes = []
        meshes_text = []

        if n_det > 0:
            for idx, (corners3D, center_cam, center_2D, dimensions, pose, score, cat_idx) in enumerate(zip(
                    dets.pred_bbox3D, dets.pred_center_cam, dets.pred_center_2D, dets.pred_dimensions,
                    dets.pred_pose, dets.scores, dets.pred_classes
            )):
                # skip low-confidence detections
                if score < thres:
                    continue

                cat = cats[cat_idx]

                bbox3D = center_cam.tolist() + dimensions.tolist()
                meshes_text.append('{} {:.2f}'.format(cat, score))
                color = [c / 255.0 for c in util.get_color(idx)]
                box_mesh = util.mesh_cuboid(bbox3D, pose.tolist(), color=color)
                meshes.append(box_mesh)

        print('File: {} with {} dets'.format(im_name, len(meshes)))

        if len(meshes) > 0:
            im_drawn_rgb, im_topdown, _ = vis.draw_scene_view(im, K, meshes, text=meshes_text, scale=im.shape[0], blend_weight=0.5, blend_weight_overlay=0.85)

            if args.display:
                im_concat = np.concatenate((im_drawn_rgb, im_topdown), axis=1)
                vis.imshow(im_concat)

            util.imwrite(im_drawn_rgb, os.path.join(output_dir, im_name + '_boxes.jpg'))
            util.imwrite(im_topdown, os.path.join(output_dir, im_name + '_novel.jpg'))
        else:
            # no detections above threshold: still emit the raw image for inspection
            util.imwrite(im, os.path.join(output_dir, im_name + '_boxes.jpg'))
def setup(args):
    """Build, freeze, and return the detectron2 config for this run.

    Starts from the detectron2 defaults, layers in the Cube R-CNN defaults,
    then merges the YAML file named by ``args.config_file`` and any
    ``KEY VALUE`` overrides from ``args.opts`` before freezing.
    """
    cfg = get_cfg()
    get_cfg_defaults(cfg)

    # remote configs (handler-prefixed paths) are cached to a local file first
    config_path = args.config_file
    if config_path.startswith(util.CubeRCNNHandler.PREFIX):
        config_path = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, config_path)

    cfg.merge_from_file(config_path)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    default_setup(cfg, args)
    return cfg
def main(args):
    """Per-process entry point: build the model, load weights, run inference."""
    cfg = setup(args)
    model = build_model(cfg)

    logger.info("Model:\n{}".format(model))

    # restore the latest checkpoint if present, otherwise cfg.MODEL.WEIGHTS
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True)

    # inference only: no gradients needed
    with torch.no_grad():
        do_test(args, cfg, model)
if __name__ == "__main__":

    # CLI for the demo; mirrors detectron2's default launcher arguments.
    parser = argparse.ArgumentParser(
        epilog=None, formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument('--input-folder', type=str, help='list of image folders to process', required=True)
    parser.add_argument("--focal-length", type=float, default=0, help="focal length for image inputs (in px)")
    parser.add_argument("--principal-point", type=float, default=[], nargs=2, help="principal point for image inputs (in px)")
    parser.add_argument("--threshold", type=float, default=0.25, help="threshold on score for visualizing")
    parser.add_argument("--display", default=False, action="store_true", help="Whether to show the images in matplotlib",)
    parser.add_argument("--eval-only", default=True, action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # pick a deterministic-per-user port in [49152, 65536) for distributed init
    port_base = 2 ** 15 + 2 ** 14
    port_seed = os.getuid() if sys.platform != "win32" else 1
    dist_port = port_base + hash(port_seed) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(dist_port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
        "See config references at "
        "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    print("Command Line Args:", args)

    # spawn one process per GPU; each runs main(args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )