#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Script to pre-process the CO3D dataset.
# Usage:
# python3 datasets_preprocess/preprocess_co3d.py --co3d_dir /path/to/co3d
# --------------------------------------------------------
import argparse
import gzip
import json
import os
import os.path as osp
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
from tqdm.auto import tqdm

# path_to_root puts the repo root on sys.path, so it must be imported
# before anything from the dust3r package
import path_to_root  # noqa
import dust3r.datasets.utils.cropping as cropping  # noqa

CATEGORIES = [
    "apple",
    "backpack",
    "ball",
    "banana",
    "baseballbat",
    "baseballglove",
    "bench",
    "bicycle",
    "book",
    "bottle",
    "bowl",
    "broccoli",
    "cake",
    "car",
    "carrot",
    "cellphone",
    "chair",
    "couch",
    "cup",
    "donut",
    "frisbee",
    "hairdryer",
    "handbag",
    "hotdog",
    "hydrant",
    "keyboard",
    "kite",
    "laptop",
    "microwave",
    "motorcycle",
    "mouse",
    "orange",
    "parkingmeter",
    "pizza",
    "plant",
    "remote",
    "sandwich",
    "skateboard",
    "stopsign",
    "suitcase",
    "teddybear",
    "toaster",
    "toilet",
    "toybus",
    "toyplane",
    "toytrain",
    "toytruck",
    "tv",
    "umbrella",
    "vase",
    "wineglass",
]
CATEGORIES_IDX = {cat: i for i, cat in enumerate(CATEGORIES)}  # for per-category seeding
SINGLE_SEQUENCE_CATEGORIES = sorted(
    set(CATEGORIES) - set(["microwave", "stopsign", "tv"])
)

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--category", type=str, default=None)
    parser.add_argument(
        "--single_sequence_subset",
        default=False,
        action="store_true",
        help="prepare the single_sequence_subset instead.",
    )
    parser.add_argument("--output_dir", type=str, default="data/co3d_processed")
    parser.add_argument("--co3d_dir", type=str, required=True)
    parser.add_argument("--num_sequences_per_object", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--min_quality",
        type=float,
        default=0.5,
        help="Minimum viewpoint quality score.",
    )
    parser.add_argument(
        "--img_size",
        type=int,
        default=512,
        help=(
            "lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size"
        ),
    )
    return parser
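

# Example invocation (hypothetical paths), combining the flags defined in
# get_parser above to process a single category:
#
#   python3 datasets_preprocess/preprocess_co3d.py \
#       --co3d_dir /path/to/co3d --category hydrant --img_size 512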


def convert_ndc_to_pinhole(focal_length, principal_point, image_size):
    # Convert PyTorch3D NDC intrinsics into a pixel-space 3x3 pinhole matrix K.
    # image_size is (H, W); NDC units are rescaled by half the smaller side.
    focal_length = np.array(focal_length)
    principal_point = np.array(principal_point)
    image_size_wh = np.array([image_size[1], image_size[0]])
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point_px = half_image_size - principal_point * rescale
    focal_length_px = focal_length * rescale
    fx, fy = focal_length_px[0], focal_length_px[1]
    cx, cy = principal_point_px[0], principal_point_px[1]
    K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)
    return K
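

# A minimal sanity check of the NDC-to-pinhole conversion (values hypothetical):
# for a 600x800 (H, W) image the rescale factor is min(400, 300) = 300, so an
# NDC focal length of (2.0, 2.0) maps to 600 px and a principal point at the
# NDC origin lands at the image center (400, 300).
#
#   K = convert_ndc_to_pinhole([2.0, 2.0], [0.0, 0.0], (600, 800))
#   assert np.allclose(K, [[600.0, 0.0, 400.0], [0.0, 600.0, 300.0], [0.0, 0.0, 1.0]])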


def opencv_from_cameras_projection(R, T, focal, p0, image_size):
    # Convert a PyTorch3D camera (R, T, NDC focal / principal point) to the
    # OpenCV convention: returns world-to-camera R, tvec and pixel-space K.
    R = torch.from_numpy(R)[None, :, :]
    T = torch.from_numpy(T)[None, :]
    focal = torch.from_numpy(focal)[None, :]
    p0 = torch.from_numpy(p0)[None, :]
    image_size = torch.from_numpy(image_size)[None, :]

    # PyTorch3D and OpenCV use opposite x/y axis conventions: flip both axes,
    # then transpose R to go from row-vector to column-vector convention.
    R_pytorch3d = R.clone()
    T_pytorch3d = T.clone()
    focal_pytorch3d = focal
    p0_pytorch3d = p0
    T_pytorch3d[:, :2] *= -1
    R_pytorch3d[:, :, :2] *= -1
    tvec = T_pytorch3d
    R = R_pytorch3d.permute(0, 2, 1)

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1,))

    # NDC to screen conversion.
    scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
    scale = scale.expand(-1, 2)
    c0 = image_size_wh / 2.0

    principal_point = -p0_pytorch3d * scale + c0
    focal_length = focal_pytorch3d * scale

    camera_matrix = torch.zeros_like(R)
    camera_matrix[:, :2, 2] = principal_point
    camera_matrix[:, 2, 2] = 1.0
    camera_matrix[:, 0, 0] = focal_length[:, 0]
    camera_matrix[:, 1, 1] = focal_length[:, 1]
    return R[0], tvec[0], camera_matrix[0]
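

# The function above appears to be a single-camera, numpy-input variant of
# pytorch3d.utils.opencv_from_cameras_projection. A sketch of the assumed
# equivalence (requires pytorch3d, which this script does not depend on):
#
#   from pytorch3d.renderer import PerspectiveCameras
#   from pytorch3d.utils import opencv_from_cameras_projection as p3d_opencv
#   cam = PerspectiveCameras(R=torch.from_numpy(R)[None], T=torch.from_numpy(T)[None],
#                            focal_length=torch.from_numpy(focal)[None],
#                            principal_point=torch.from_numpy(p0)[None])
#   R_cv, tvec_cv, K_cv = p3d_opencv(cam, torch.from_numpy(image_size)[None])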


def get_set_list(category_dir, split, is_single_sequence_subset=False):
    listfiles = os.listdir(osp.join(category_dir, "set_lists"))
    if is_single_sequence_subset:
        # not all objects have manyview_dev
        subset_list_files = [f for f in listfiles if "manyview_dev" in f]
    else:
        subset_list_files = [f for f in listfiles if "fewview_train" in f]

    sequences_all = []
    for subset_list_file in subset_list_files:
        with open(osp.join(category_dir, "set_lists", subset_list_file)) as f:
            subset_lists_data = json.load(f)
            sequences_all.extend(subset_lists_data[split])

    return sequences_all
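

# Each set-list entry is a (sequence_name, frame_number, image_path) triple,
# as unpacked in prepare_sequences below, e.g. (values hypothetical):
#
#   ["110_13051_23361", 42, "apple/110_13051_23361/images/frame000042.jpg"]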


def prepare_sequences(
    category,
    co3d_dir,
    output_dir,
    img_size,
    split,
    min_quality,
    max_num_sequences_per_object,
    seed,
    is_single_sequence_subset=False,
):
    random.seed(seed)
    category_dir = osp.join(co3d_dir, category)
    category_output_dir = osp.join(output_dir, category)
    sequences_all = get_set_list(category_dir, split, is_single_sequence_subset)
    sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all))

    frame_file = osp.join(category_dir, "frame_annotations.jgz")
    sequence_file = osp.join(category_dir, "sequence_annotations.jgz")

    with gzip.open(frame_file, "r") as fin:
        frame_data = json.loads(fin.read())
    with gzip.open(sequence_file, "r") as fin:
        sequence_data = json.loads(fin.read())

    # index the frame annotations by (sequence_name, frame_number)
    frame_data_processed = {}
    for f_data in frame_data:
        sequence_name = f_data["sequence_name"]
        frame_data_processed.setdefault(sequence_name, {})[
            f_data["frame_number"]
        ] = f_data

    # keep only sequences whose viewpoint quality exceeds the threshold
    good_quality_sequences = set()
    for seq_data in sequence_data:
        if seq_data["viewpoint_quality_score"] > min_quality:
            good_quality_sequences.add(seq_data["sequence_name"])

    sequences_numbers = [
        seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences
    ]
    if len(sequences_numbers) < max_num_sequences_per_object:
        selected_sequences_numbers = sequences_numbers
    else:
        selected_sequences_numbers = random.sample(
            sequences_numbers, max_num_sequences_per_object
        )

    selected_sequences_numbers_dict = {
        seq_name: [] for seq_name in selected_sequences_numbers
    }
    sequences_all = [
        (seq_name, frame_number, filepath)
        for seq_name, frame_number, filepath in sequences_all
        if seq_name in selected_sequences_numbers_dict
    ]

    for seq_name, frame_number, filepath in tqdm(sequences_all):
        frame_idx = int(filepath.split("/")[-1][5:-4])
        selected_sequences_numbers_dict[seq_name].append(frame_idx)
        mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")
        frame_data = frame_data_processed[seq_name][frame_number]
        focal_length = frame_data["viewpoint"]["focal_length"]
        principal_point = frame_data["viewpoint"]["principal_point"]
        image_size = frame_data["image"]["size"]
        K = convert_ndc_to_pinhole(focal_length, principal_point, image_size)
        R, tvec, camera_intrinsics = opencv_from_cameras_projection(
            np.array(frame_data["viewpoint"]["R"]),
            np.array(frame_data["viewpoint"]["T"]),
            np.array(focal_length),
            np.array(principal_point),
            np.array(image_size),
        )

        depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"])
        assert frame_data["depth"]["scale_adjustment"] == 1.0
        image_path = os.path.join(co3d_dir, filepath)
        mask_path_full = os.path.join(co3d_dir, mask_path)

        input_rgb_image = PIL.Image.open(image_path).convert("RGB")
        input_mask = plt.imread(mask_path_full)

        with PIL.Image.open(depth_path) as depth_pil:
            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
            # we cast it to uint16, then reinterpret as float16, then cast to float32
            input_depthmap = (
                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
                .astype(np.float32)
                .reshape((depth_pil.size[1], depth_pil.size[0]))
            )
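            # bit-cast sanity check (hypothetical pixel): the uint16 value
            # 15360 is 0x3C00, which reinterprets as float16 1.0, so such a
            # pixel decodes to a depth of exactly 1.0 in the dataset's scale
            # (scale_adjustment is asserted to be 1.0 above)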

        depth_mask = np.stack((input_depthmap, input_mask), axis=-1)
        H, W = input_depthmap.shape

        camera_intrinsics = camera_intrinsics.numpy()
        cx, cy = camera_intrinsics[:2, 2].round().astype(int)

        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y)
        # centered on (cx, cy), so that the principal point ends up at the image center
        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        (
            input_rgb_image,
            depth_mask,
            input_camera_intrinsics,
        ) = cropping.crop_image_depthmap(
            input_rgb_image, depth_mask, camera_intrinsics, crop_bbox
        )

        # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384
        scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8
        output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)
        if max(output_resolution) < img_size:
            # let's put the max dimension to img_size
            scale_final = (img_size / max(H, W)) + 1e-8
            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)

        (
            input_rgb_image,
            depth_mask,
            input_camera_intrinsics,
        ) = cropping.rescale_image_depthmap(
            input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution
        )
        input_depthmap = depth_mask[:, :, 0]
        input_mask = depth_mask[:, :, 1]

        # generate and adjust camera pose: invert the world-to-camera extrinsics
        # into a 4x4 camera-to-world pose matrix
        camera_pose = np.eye(4, dtype=np.float32)
        camera_pose[:3, :3] = R
        camera_pose[:3, 3] = tvec
        camera_pose = np.linalg.inv(camera_pose)

        # save cropped images, depth and metadata
        save_img_path = os.path.join(output_dir, filepath)
        save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"])
        save_mask_path = os.path.join(output_dir, mask_path)
        os.makedirs(os.path.split(save_img_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True)

        input_rgb_image.save(save_img_path)
        # quantize the depth to uint16 over [0, max]; the scale is recorded in the
        # metadata as maximum_depth so the map can be de-quantized when loading
        scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(
            np.uint16
        )
        cv2.imwrite(save_depth_path, scaled_depth_map)
        cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))

        save_meta_path = save_img_path.replace("jpg", "npz")
        np.savez(
            save_meta_path,
            camera_intrinsics=input_camera_intrinsics,
            camera_pose=camera_pose,
            maximum_depth=np.max(input_depthmap),
        )

    return selected_sequences_numbers_dict
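

# A minimal sketch of reading one processed frame back (assumes the layout
# written above; save_*_path stand for the per-frame paths from the loop):
#
#   rgb = PIL.Image.open(save_img_path)
#   meta = np.load(save_img_path.replace("jpg", "npz"))
#   depth_u16 = cv2.imread(save_depth_path, cv2.IMREAD_UNCHANGED)
#   depthmap = depth_u16.astype(np.float32) / 65535.0 * meta["maximum_depth"]
#   pose_c2w = meta["camera_pose"]  # 4x4 camera-to-world
#   K = meta["camera_intrinsics"]   # 3x3 pinhole intrinsics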


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    assert args.co3d_dir != args.output_dir

    if args.category is None:
        if args.single_sequence_subset:
            categories = SINGLE_SEQUENCE_CATEGORIES
        else:
            categories = CATEGORIES
    else:
        categories = [args.category]
    os.makedirs(args.output_dir, exist_ok=True)

    for split in ["train", "test"]:
        selected_sequences_path = os.path.join(
            args.output_dir, f"selected_seqs_{split}.json"
        )
        if os.path.isfile(selected_sequences_path):
            continue

        all_selected_sequences = {}
        for category in categories:
            category_output_dir = osp.join(args.output_dir, category)
            os.makedirs(category_output_dir, exist_ok=True)
            category_selected_sequences_path = os.path.join(
                category_output_dir, f"selected_seqs_{split}.json"
            )
            if os.path.isfile(category_selected_sequences_path):
                with open(category_selected_sequences_path, "r") as fid:
                    category_selected_sequences = json.load(fid)
            else:
                print(f"Processing {split} - category = {category}")
                category_selected_sequences = prepare_sequences(
                    category=category,
                    co3d_dir=args.co3d_dir,
                    output_dir=args.output_dir,
                    img_size=args.img_size,
                    split=split,
                    min_quality=args.min_quality,
                    max_num_sequences_per_object=args.num_sequences_per_object,
                    seed=args.seed + CATEGORIES_IDX[category],
                    is_single_sequence_subset=args.single_sequence_subset,
                )
                with open(category_selected_sequences_path, "w") as file:
                    json.dump(category_selected_sequences, file)

            all_selected_sequences[category] = category_selected_sequences

        with open(selected_sequences_path, "w") as file:
            json.dump(all_selected_sequences, file)
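
# The resulting selected_seqs_{split}.json maps each category to a dict from
# sequence name to the list of kept frame indices, e.g. (values hypothetical):
#
#   {"apple": {"110_13051_23361": [1, 5, 9]}, "bench": {...}, ...}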