# Spaces: Sleeping
import numpy as np
import torch
import torch.utils.data as data
import cv2
import os
import h5py
import random
import sys

# Prepend the parent of this file's directory (presumably the repository root)
# to sys.path so that the project-local ``utils`` package below is importable.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
sys.path.insert(0, ROOT_DIR)
from utils import train_utils, evaluation_utils  # noqa: E402  (needs sys.path above)

# Share tensors between processes (e.g. DataLoader workers) through named
# shared-memory files rather than file descriptors; see the
# torch.multiprocessing "sharing strategies" documentation.
torch.multiprocessing.set_sharing_strategy("file_system")
class Offline_Dataset(data.Dataset):
    """Image-pair dataset built from precomputed features and pair annotations.

    Pair metadata (relative pose, intrinsics, image sizes, image paths and
    correspondence annotations) is read from ``<dataset_path>/<seq>/info.h5py``.
    Descriptors and keypoints are read from per-image HDF5 files named
    ``<desc_path>/<seq>/<image file name><desc_suffix>``.

    ``config`` attributes used: ``dataset_path``, ``rawdata_path``,
    ``desc_path``, ``desc_suffix``, ``num_kpt``, ``input_normalize``
    (``"intrinsic"`` or ``"img"``) and ``data_aug``.
    """

    # Per-sample keys gathered by ``collate_fn``, in output order.
    _SAMPLE_KEYS = (
        "x1",
        "x2",
        "kpt1",
        "kpt2",
        "desc1",
        "desc2",
        "num_corr",
        "num_incorr1",
        "num_incorr2",
        "e_gt",
        "pscore1",
        "pscore2",
        "img_path1",
        "img_path2",
    )
    # Keys stacked into float32 tensors by ``collate_fn``.
    _FLOAT_KEYS = (
        "x1",
        "x2",
        "kpt1",
        "kpt2",
        "desc1",
        "desc2",
        "e_gt",
        "pscore1",
        "pscore2",
    )
    # Keys stacked into int32 tensors by ``collate_fn``.
    _INT_KEYS = ("num_corr", "num_incorr1", "num_incorr2")

    def __init__(self, config, mode):
        """Load the pair index for ``mode``.

        Args:
            config: configuration object (attributes listed on the class).
            mode: ``"train"`` or ``"valid"``; selects which
                ``<dataset_path>/<mode>/pair_num.txt`` index is loaded.
        """
        assert mode == "train" or mode == "valid"
        self.config = config
        self.mode = mode
        metadir = os.path.join(
            config.dataset_path, "valid" if mode == "valid" else "train"
        )
        pair_num_list = np.loadtxt(os.path.join(metadir, "pair_num.txt"), dtype=str)
        # Row 0, column 1 of pair_num.txt is taken as the total pair count.
        self.total_pairs = int(pair_num_list[0, 1])
        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(
            pair_num_list
        )

    @staticmethod
    def _warp_points(points, homo_mat):
        """Apply per-sample homographies to a (batch, num_pts, 2) point tensor."""
        # Lift to homogeneous coordinates as (batch, num_pts, 3, 1) column vectors.
        homo = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
        homo = torch.matmul(homo_mat.float(), homo.unsqueeze(-1).float()).squeeze(-1)
        return homo[:, :, :2] / homo[:, :, 2].unsqueeze(-1)

    def collate_fn(self, batch):
        """Merge samples produced by ``__getitem__`` into one batch dict.

        Array-valued entries are stacked along a new leading batch axis and
        converted to float32 tensors, or to int32 tensors for the annotation
        counts. Image paths stay Python lists.

        In ``"train"`` mode with ``config.data_aug`` set, one batch of random
        homographies (``train_utils.get_rnd_homography``) is applied to the
        normalized coordinates of image 1 or image 2, chosen 50/50. The warped
        result is stored under ``aug_x1`` or ``aug_x2``, and the other key
        aliases the unwarped tensor. Otherwise ``aug_x1``/``aug_x2`` alias
        ``x1``/``x2``.
        """
        batch_size = len(batch)
        merged = {key: [sample[key] for sample in batch] for key in self._SAMPLE_KEYS}
        for key in self._FLOAT_KEYS:
            merged[key] = torch.from_numpy(np.stack(merged[key])).float()
        for key in self._INT_KEYS:
            merged[key] = torch.from_numpy(np.stack(merged[key])).int()

        # kpt augmentation with random homography
        if self.mode == "train" and self.config.data_aug:
            # NOTE(review): assumes get_rnd_homography returns (batch, 3, 3);
            # unsqueeze(1) broadcasts it over the point axis.
            homo_mat = torch.from_numpy(
                train_utils.get_rnd_homography(batch_size)
            ).unsqueeze(1)
            if random.random() < 0.5:
                merged["aug_x1"] = self._warp_points(merged["x1"], homo_mat)
                merged["aug_x2"] = merged["x2"]
            else:
                merged["aug_x2"] = self._warp_points(merged["x2"], homo_mat)
                merged["aug_x1"] = merged["x1"]
        else:
            merged["aug_x1"], merged["aug_x2"] = merged["x1"], merged["x2"]
        return merged

    @staticmethod
    def _read_features(path):
        """Return (descriptors, keypoint xy, third keypoint column) from one file."""
        with h5py.File(path, "r") as fea:
            keypoints = fea["keypoints"][()]  # read once, then split columns
            return fea["descriptors"][()], keypoints[:, :2], keypoints[:, 2]

    @staticmethod
    def _essential_from_pose(R, t):
        """Return ``skew(t) @ R`` scaled to unit Frobenius norm.

        This is the essential matrix, assuming
        ``evaluation_utils.np_skew_symmetric`` builds the cross-product matrix
        of ``t``.
        """
        t_skew = np.reshape(
            evaluation_utils.np_skew_symmetric(t.astype("float64").reshape(1, 3)),
            (3, 3),
        )
        egt = np.matmul(t_skew, np.reshape(R.astype("float64"), (3, 3)))
        return egt / np.linalg.norm(egt)

    def _normalize_keypoints(self, kpt1, kpt2, K1, K2, size1, size2, egt):
        """Map pixel keypoints into the configured frame and adapt ``egt`` to it.

        ``"intrinsic"``: coordinates are ``K^-1 [u, v, 1]^T``, keeping the first
        two components without dividing by the third (exact when K's last row
        is ``[0, 0, 1]``). ``egt`` is unchanged.

        ``"img"``: coordinates are ``(kpt - size / 2) / size``. ``egt`` becomes
        ``M2^T E M1`` with ``M = K^-1 S``, where ``S`` maps these coordinates
        back to homogeneous pixels. The result is re-normalized to unit norm.

        Raises:
            NotImplementedError: for any other ``config.input_normalize``.
        """
        mode = self.config.input_normalize
        if mode == "intrinsic":
            h1 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1)
            h2 = np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
            x1 = np.matmul(np.linalg.inv(K1), h1.T).T[:, :2]
            x2 = np.matmul(np.linalg.inv(K2), h2.T).T[:, :2]
            return x1, x2, egt
        if mode == "img":
            x1 = (kpt1 - size1 / 2) / size1
            x2 = (kpt2 - size2 / 2) / size2

            def to_pixels(size):
                # Homogeneous inverse of x = (kpt - size / 2) / size.
                return np.asarray(
                    [
                        [size[0], 0, 0.5 * size[0]],
                        [0, size[1], 0.5 * size[1]],
                        [0, 0, 1],
                    ]
                )

            M1 = np.matmul(np.linalg.inv(K1), to_pixels(size1))
            M2 = np.matmul(np.linalg.inv(K2), to_pixels(size2))
            egt = np.matmul(np.matmul(M2.transpose(), egt), M1)
            return x1, x2, egt / np.linalg.norm(egt)
        raise NotImplementedError

    @staticmethod
    def _build_permutation(num_points, corr_idx, incorr_idx, num_pad):
        """Return row indices ``[corr_idx, incorr_idx, padding]`` for one image.

        The padding consists of ``num_pad`` indices of keypoints that appear in
        neither annotation set. They are drawn without replacement when enough
        exist. Otherwise every such keypoint is taken once, followed by
        uniformly resampled repeats.

        Raises:
            ValueError: if padding is needed but every keypoint is annotated.
        """
        unannotated = np.ones(num_points, dtype=bool)
        unannotated[corr_idx] = False
        unannotated[incorr_idx] = False
        candidates = np.nonzero(unannotated)[0]
        num_candidates = len(candidates)
        if num_candidates >= num_pad:
            sub_idx = np.random.choice(num_candidates, num_pad, replace=False)
        elif num_candidates == 0:
            raise ValueError(
                f"cannot pad to num_kpt: {num_pad} extra keypoints are needed "
                "but every keypoint is already annotated"
            )
        else:
            sub_idx = np.concatenate(
                [
                    np.arange(num_candidates),
                    np.random.randint(num_candidates, size=num_pad - num_candidates),
                ]
            )
        return np.concatenate([corr_idx, incorr_idx, candidates[sub_idx]])

    def __getitem__(self, index):
        """Load, normalize and resample the keypoints of pair ``index``.

        Returns a dict with:
            x1, x2: normalized keypoint coordinates, one row per keypoint.
            kpt1, kpt2: keypoint pixel coordinates, rows aligned with x1/x2.
            desc1, desc2: descriptors, rows aligned with the keypoints.
            pscore1, pscore2: third keypoint column as an (N, 1) array.
            num_corr, num_incorr1, num_incorr2: annotation counts. Each image's
                rows are ordered [correspondences | incorrect | padding], so the
                first ``num_corr`` rows of both images correspond, and every
                image has ``config.num_kpt`` rows.
            e_gt: (3, 3) unit-norm epipolar matrix in the normalized frame.
            img_path1, img_path2: image paths joined onto ``config.rawdata_path``.
        """
        config = self.config
        num_kpt = config.num_kpt
        seq = self.pair_seq_list[index]
        # Entries inside info.h5py are keyed by the pair's index within its sequence.
        key = str(index - self.accu_pair_num[seq])
        # NOTE(review): unlike pair_num.txt, this path does not include the
        # train/valid subdirectory; confirm sequences live under dataset_path.
        info_path = os.path.join(config.dataset_path, seq, "info.h5py")
        with h5py.File(info_path, "r") as info:
            R, t = info["dR"][key][()], info["dt"][key][()]
            K1, K2 = info["K1"][key][()], info["K2"][key][()]
            size1, size2 = info["size1"][key][()], info["size2"][key][()]
            rel_path1 = info["img_path1"][key][()][0].decode()
            rel_path2 = info["img_path2"][key][()][0].decode()
            corr = info["corr"][key][()]
            incorr1, incorr2 = info["incorr1"][key][()], info["incorr2"][key][()]

        img_path1 = os.path.join(config.rawdata_path, rel_path1)
        img_path2 = os.path.join(config.rawdata_path, rel_path2)
        # Feature files are named after the image file name (last "/" component).
        desc1, kpt1, pscore1 = self._read_features(
            os.path.join(
                config.desc_path, seq, rel_path1.split("/")[-1] + config.desc_suffix
            )
        )
        desc2, kpt2, pscore2 = self._read_features(
            os.path.join(
                config.desc_path, seq, rel_path2.split("/")[-1] + config.desc_suffix
            )
        )
        # Keep at most num_kpt keypoints per image (the first ones stored).
        kpt1, desc1 = kpt1[:num_kpt], desc1[:num_kpt]
        kpt2, desc2 = kpt2[:num_kpt], desc2[:num_kpt]

        egt = self._essential_from_pose(R, t)
        x1, x2, egt = self._normalize_keypoints(kpt1, kpt2, K1, K2, size1, size2, egt)

        # Drop annotations that reference keypoints removed by the truncation.
        # corr column 0 indexes image 1 and column 1 indexes image 2.
        valid_corr = corr[corr.max(axis=-1) < num_kpt]
        valid_incorr1 = incorr1[incorr1 < num_kpt]
        valid_incorr2 = incorr2[incorr2 < num_kpt]
        num_corr = len(valid_corr)
        num_incorr1, num_incorr2 = len(valid_incorr1), len(valid_incorr2)

        # permute kpt: [correspondences | incorrect | random unannotated padding]
        per_idx1 = self._build_permutation(
            x1.shape[0], valid_corr[:, 0], valid_incorr1, num_kpt - num_corr - num_incorr1
        )
        per_idx2 = self._build_permutation(
            x2.shape[0], valid_corr[:, 1], valid_incorr2, num_kpt - num_corr - num_incorr2
        )

        return {
            "x1": x1[per_idx1][:, :2],
            "x2": x2[per_idx2][:, :2],
            "kpt1": kpt1[per_idx1],
            "kpt2": kpt2[per_idx2],
            "desc1": desc1[per_idx1],
            "desc2": desc2[per_idx2],
            "num_corr": num_corr,
            "num_incorr1": num_incorr1,
            "num_incorr2": num_incorr2,
            "e_gt": egt,
            "pscore1": pscore1[per_idx1][:, np.newaxis],
            "pscore2": pscore2[per_idx2][:, np.newaxis],
            "img_path1": img_path1,
            "img_path2": img_path2,
        }

    def __len__(self):
        """Return the total pair count read from ``pair_num.txt``."""
        return self.total_pairs