Spaces:
Sleeping
Sleeping
import numpy as np | |
import argparse | |
import copy | |
import os, sys | |
import open3d as o3d | |
from sys import argv, exit | |
from PIL import Image | |
import math | |
from tqdm import tqdm | |
import cv2 | |
sys.path.append("../../") | |
from lib.extractMatchTop import getPerspKeypoints, getPerspKeypointsEnsemble, siftMatching | |
import pandas as pd | |
import torch | |
from lib.model_test import D2Net | |
#### Cuda ####
# Run on the first CUDA device when torch reports one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
#### Argument Parsing ####
# Command-line options, declared as (flag, add_argument keyword arguments) in
# the order they are registered with the parser.
_CLI_OPTIONS = (
    ('--dataset', dict(type=str, default='/scratch/udit/realsense/RoRD_data/preprocessed/',
                       help='path to the dataset folder')),
    ('--sequence', dict(type=str, default='data1')),
    ('--output_dir', dict(type=str, default='out',
                          help='output directory for RT estimates')),
    ('--model_rord', dict(type=str, help='path to the RoRD model for evaluation')),
    ('--model_d2', dict(type=str, help='path to the vanilla D2-Net model for evaluation')),
    ('--model_ens', dict(action='store_true',
                         help='ensemble model of RoRD + D2-Net')),
    ('--sift', dict(action='store_true', help='Sift')),
    ('--viz3d', dict(action='store_true',
                     help='visualize the pointcloud registrations')),
    ('--log_interval', dict(type=int, default=9,
                            help='Matched image logging interval')),
    ('--camera_file', dict(type=str, default='../../configs/camera.txt',
                           help='path to the camera intrinsics file. In order: focal_x, focal_y, center_x, center_y, scaling_factor.')),
    ('--persp', dict(action='store_true', default=False,
                     help='Feature matching on perspective images.')),
)

parser = argparse.ArgumentParser(description='RoRD ICP evaluation on a DiverseView dataset sequence.')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
parser.set_defaults(fp16=False)

# Parsed once at module level; the rest of the script reads `args` directly.
args = parser.parse_args()
# Checkpoint paths for the ensemble run; these names exist only when
# --model_ens was passed. Change default paths accordingly for ensemble.
if args.model_ens:
    model1_ens, model2_ens = '../../models/rord.pth', '../../models/d2net.pth'
def draw_registration_result(source, target, transformation):
    """Show `source` moved by `transformation` next to `target`.

    Deep copies of both inputs are used, so the caller's geometries are not
    modified. Two coordinate frames are added: one at the origin and one moved
    by `transformation`.

    NOTE: the four new geometries are appended to the module-level list
    `trgSph`, and that whole list is what gets displayed. `trgSph` must
    therefore exist before this function is called.
    """
    moved_source = copy.deepcopy(source)
    moved_source.transform(transformation)
    target_copy = copy.deepcopy(target)

    origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
    moved_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
    moved_frame.transform(transformation)

    trgSph.extend([moved_source, target_copy, origin_frame, moved_frame])
    o3d.visualization.draw_geometries(trgSph)
def readDepth(depthFile):
    """Load a depth image from `depthFile` and return it as a numpy array.

    Raises:
        Exception: if the opened image's mode is not "I".
    """
    image = Image.open(depthFile)
    if image.mode == "I":
        return np.asarray(image)
    raise Exception("Depth image is not in intensity format")
def readCamera(camera):
    """Read camera intrinsics from a whitespace-separated text file.

    The first five tokens of the file are parsed as floats and returned, in
    order, as (focalX, focalY, centerX, centerY, scalingFactor). Any further
    tokens are ignored.

    Raises:
        IndexError: if the file holds fewer than five tokens.
        ValueError: if one of the first five tokens is not a valid float.
    """
    with open(camera, "rt") as handle:
        tokens = handle.read().split()

    # Index explicitly (rather than slicing) so a short file raises IndexError.
    return tuple(float(tokens[position]) for position in range(5))
def getPointCloud(rgbFile, depthFile, pts, thresh=15.0):
    """Back-project a depth image into a coloured point cloud and locate keypoints in it.

    Every pixel (u, v) whose scaled depth Z = depth[v, u] / scalingFactor is
    non-zero and not greater than `thresh` becomes one 3D point, visited in
    row-major order. Intrinsics are read from the module-level names
    `focalX`, `focalY`, `centerX`, `centerY` and `scalingFactor`, which must
    be set before calling.

    Args:
        rgbFile: path of the colour image; colours are sampled at the same
            (u, v) as the depth pixel.
        depthFile: path of the depth image, loaded through `readDepth`.
        pts: list of (u, v) pixel tuples (must be hashable). If a pixel
            appears more than once, only its first occurrence is filled in.
        thresh: largest accepted value of Z; pixels with a larger Z are dropped.

    Returns:
        (pcd, corIdx, corPts) where `pcd` is the Open3D point cloud,
        `corIdx[i]` is the index in `pcd` of the point for `pts[i]` (or -1
        when that pixel produced no point) and `corPts[i]` is its (X, Y, Z)
        tuple (or None).
    """
    depth = readDepth(depthFile)
    rgb = Image.open(rgbFile)

    points = []
    colors = []

    corIdx = [-1] * len(pts)
    corPts = [None] * len(pts)

    # Pixel -> first position in `pts`. This replaces a per-pixel linear scan
    # of `pts` (`in` + `.index`) with an O(1) lookup while keeping the
    # "first occurrence wins" behaviour of `list.index`.
    firstIndex = {}
    for index, pixel in enumerate(pts):
        firstIndex.setdefault(pixel, index)

    ptIdx = 0
    for v in range(depth.shape[0]):
        for u in range(depth.shape[1]):
            Z = depth[v, u] / scalingFactor
            # Skip pixels with no depth reading and pixels beyond the cutoff.
            if Z == 0 or Z > thresh:
                continue

            X = (u - centerX) * Z / focalX
            Y = (v - centerY) * Z / focalY

            points.append((X, Y, Z))
            colors.append(rgb.getpixel((u, v)))

            index = firstIndex.get((u, v))
            if index is not None:
                corIdx[index] = ptIdx
                corPts[index] = (X, Y, Z)

            ptIdx += 1

    points = np.asarray(points)
    colors = np.asarray(colors)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    # Colours are divided by 255 before being stored on the cloud.
    pcd.colors = o3d.utility.Vector3dVector(colors / 255)

    return pcd, corIdx, corPts
def convertPts(A):
    """Turn two parallel coordinate sequences into a list of integer 2-tuples.

    `A[0]` supplies the first component of every tuple and `A[1]` the second.
    Each value is converted with int(float(value)), i.e. truncated toward
    zero. The result has one tuple per entry of `A[0]`.

    Raises:
        IndexError: if `A[1]` is shorter than `A[0]`.
    """
    raw_first, raw_second = A[0], A[1]
    firsts = [int(float(value)) for value in raw_first]
    seconds = [int(float(value)) for value in raw_second]
    # Index into `seconds` instead of zipping, so a shorter second row still
    # raises rather than silently truncating the result.
    return [(first, seconds[k]) for k, first in enumerate(firsts)]
def getSphere(pts):
    """Build one marker sphere per 3D point in `pts`, skipping None entries.

    Each marker is an Open3D sphere mesh of radius 0.03, painted
    [0.9, 0.2, 0] and translated to the point's (x, y, z). Returns the list
    of meshes in input order.
    """
    markers = []
    for center in pts:
        if center is None:
            continue

        # Pure translation: identity rotation with the point in the last column.
        shift = np.identity(4)
        shift[:3, 3] = (center[0], center[1], center[2])

        marker = o3d.geometry.TriangleMesh.create_sphere(radius=0.03)
        marker.paint_uniform_color([0.9, 0.2, 0])
        marker.transform(shift)
        markers.append(marker)

    return markers
def get3dCor(src, trg):
    """Pair up source/target point indices, keeping only fully resolved pairs.

    `src` and `trg` are walked in lockstep; a pair is kept only when neither
    index is -1. Returns the kept (source, target) pairs as a numpy array
    (an empty array when no pair survives).
    """
    pairs = [
        (source_id, target_id)
        for source_id, target_id in zip(src, trg)
        if source_id != -1 and target_id != -1
    ]
    return np.asarray(pairs)
if __name__ == "__main__":
    # For every query image, estimate the rigid transform to every database
    # image from matched keypoints and save the stacked 4x4 matrices as
    # <output_dir>/<queryId>.npy.

    # Fail fast with a usage message instead of hitting a NameError inside the
    # loop when no matcher was selected.
    if not (args.model_rord or args.model_d2 or args.model_ens or args.sift):
        parser.error('select a matcher: --model_rord, --model_d2, --model_ens or --sift')

    camera_file = args.camera_file

    # os.path.join / normpath make the paths independent of whether --dataset
    # ends with a path separator.
    seq_dir = os.path.join(args.dataset, args.sequence)
    rgb_dir = os.path.join(seq_dir, 'rgb')
    rgb_csv = os.path.join(seq_dir, 'rtImagesRgb.csv')
    depth_csv = os.path.join(seq_dir, 'rtImagesDepth.csv')
    # Depth paths from the CSV are appended to the parent of the dataset folder.
    data_root = os.path.dirname(os.path.normpath(args.dataset))

    dir_name = args.output_dir
    vis_dir = os.path.join(args.output_dir, 'vis')
    os.makedirs(vis_dir, exist_ok=True)  # also creates output_dir itself

    # These stay module-level names: getPointCloud reads them as globals.
    focalX, focalY, centerX, centerY, scalingFactor = readCamera(camera_file)

    df_rgb = pd.read_csv(rgb_csv)
    df_dep = pd.read_csv(depth_csv)

    model1 = D2Net(model_file=args.model_d2).to(device)
    model2 = D2Net(model_file=args.model_rord).to(device)

    # The ensemble branch is only reached when neither single-model flag is
    # set, so its checkpoints can be loaded once here rather than per pair.
    use_ensemble = args.model_ens and not (args.model_rord or args.model_d2)
    if use_ensemble:
        model1 = D2Net(model_file=model1_ens).to(device)
        model2 = D2Net(model_file=model2_ens).to(device)

    queryId = 0
    for im_q, dep_q in tqdm(zip(df_rgb['query'], df_dep['query']), total=df_rgb.shape[0]):
        filter_list = []
        dbId = 0

        # Query-side file names do not change across database images.
        src_stem = os.path.splitext(os.path.basename(im_q))[0]
        srcH = os.path.join(rgb_dir, src_stem + '.npy')
        srcImg = os.path.join(rgb_dir, src_stem + '.jpg')
        depth_name_src = data_root + '/' + dep_q

        # DataFrame.items() replaces iteritems(), which pandas 2.0 removed.
        for im_d, dep_d in tqdm(zip(df_rgb.items(), df_dep.items()), total=df_rgb.shape[1]):
            if im_d[0] == 'query':
                continue

            # NOTE(review): this reads entry 1 of each database column for
            # every query; confirm that matches the CSV layout.
            trg_stem = os.path.splitext(os.path.basename(im_d[1][1]))[0]
            trgH = os.path.join(rgb_dir, trg_stem + '.npy')
            trgImg = os.path.join(rgb_dir, trg_stem + '.jpg')

            if args.model_rord:
                if args.persp:
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model2, device=device)
                else:
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model2, device)

            elif args.model_d2:
                # model1 holds the D2-Net weights in both sub-branches (the
                # perspective branch previously passed the RoRD model).
                if args.persp:
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, HFile1=None, HFile2=None, model=model1, device=device)
                else:
                    srcPts, trgPts, matchImg, _ = getPerspKeypoints(srcImg, trgImg, srcH, trgH, model1, device)

            elif args.model_ens:
                srcPts, trgPts, matchImg = getPerspKeypointsEnsemble(model1, model2, srcImg, trgImg, srcH, trgH, device)

            else:  # args.sift, guaranteed by the check at the top
                if args.persp:
                    srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, HFile1=None, HFile2=None, device=device)
                else:
                    srcPts, trgPts, matchImg, _ = siftMatching(srcImg, trgImg, srcH, trgH, device)

            # A plain list from the matcher is treated as "matching failed":
            # record the identity and move on (dbId is not advanced).
            if isinstance(srcPts, list):
                print(np.identity(4))
                filter_list.append(np.identity(4))
                continue

            srcPts = convertPts(srcPts)
            trgPts = convertPts(trgPts)

            depth_name_trg = data_root + '/' + dep_d[1][1]

            srcCld, srcIdx, srcCor = getPointCloud(srcImg, depth_name_src, srcPts)
            trgCld, trgIdx, trgCor = getPointCloud(trgImg, depth_name_trg, trgPts)

            corr = get3dCor(srcIdx, trgIdx)

            # A rigid transform needs at least three correspondences, and an
            # empty set cannot even be converted to Vector2iVector. Reuse the
            # identity fallback used for failed matches.
            if len(corr) < 3:
                filter_list.append(np.identity(4))
                continue

            p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint()
            trans_init = p2p.compute_transformation(srcCld, trgCld, o3d.utility.Vector2iVector(corr))
            filter_list.append(trans_init)

            if args.viz3d:
                # Marker meshes are only built when they will be shown.
                # trgSph must remain a module-level name because
                # draw_registration_result appends to it and displays it.
                srcSph = getSphere(srcCor)
                trgSph = getSphere(trgCor)
                axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0, 0, 0])
                srcSph.append(srcCld); srcSph.append(axis)
                trgSph.append(trgCld); trgSph.append(axis)

                o3d.visualization.draw_geometries(srcSph)
                o3d.visualization.draw_geometries(trgSph)
                draw_registration_result(srcCld, trgCld, trans_init)

            # Save the match visualization for every log_interval-th pair.
            if dbId % args.log_interval == 0:
                cv2.imwrite(os.path.join(vis_dir, "matchImg.%02d.%02d.jpg" % (queryId, dbId // args.log_interval)), matchImg)
            dbId += 1

        # Stack the per-database 4x4 matrices into one (4, 4, N) array.
        RT = np.stack(filter_list).transpose(1, 2, 0)

        np.save(os.path.join(dir_name, str(queryId) + '.npy'), RT)
        queryId += 1
        print('-----check-------', RT.shape)