Vincentqyw
add: rord libs
2c8b554
import argparse
import numpy as np
import imageio
import torch
from tqdm import tqdm
import time
import scipy
import scipy.io
import scipy.misc
from lib.model_test import D2Net
from lib.utils import preprocess_image
from lib.pyramid import process_multiscale
import cv2
import matplotlib.pyplot as plt
import os
from sys import exit, argv
from PIL import Image
from skimage.feature import match_descriptors
from skimage.measure import ransac
from skimage.transform import ProjectiveTransform, AffineTransform
import pydegensac
def extractSingle(image, model, device):
    """Extract single-scale features for one image with ``model``.

    Args:
        image: preprocessed image tensor without a batch axis; a batch axis is
            added here with ``unsqueeze(0)`` before calling ``process_multiscale``.
        model: network forwarded to ``process_multiscale``.
        device: torch device the image is moved to.

    Returns:
        dict with keys ``'keypoints'``, ``'scores'`` and ``'descriptors'`` as
        produced by ``process_multiscale`` (at ``scales=[1]``), except that the
        first two keypoint columns are swapped.
    """
    with torch.no_grad():
        kps, kp_scores, kp_descs = process_multiscale(
            image.to(device).unsqueeze(0), model, scales=[1]
        )
        # Swap the first two columns: callers use column 0 as x and column 1
        # as y when building cv2.KeyPoint objects.
        kps = kps[:, [1, 0, 2]]
    return {
        'keypoints': kps,
        'scores': kp_scores,
        'descriptors': kp_descs,
    }
def siftMatching(img1, img2, HFile1, HFile2, device):
    """Match SIFT features between two (optionally warped) images.

    Each image is loaded as RGB and, when its homography file is given, warped
    with ``cv2.warpPerspective`` into a 400x400 canvas. SIFT descriptors are
    matched by mutual nearest neighbour (``mnn_matcher``) and the matches are
    filtered with pydegensac homography RANSAC.

    Args:
        img1, img2: image file paths.
        HFile1, HFile2: paths to ``.npy`` homographies (read with ``np.load``),
            or None to use that image unwarped. Each is handled independently.
        device: torch device used for descriptor matching.

    Returns:
        Always a 4-tuple:

        * ``([], [], None, None)`` when either image yields no descriptors or
          fewer than 5 mutual matches are found;
        * ``(src_pts, dst_pts, image3, image3)`` when ``HFile1`` is None, with
          (N, 2) float32 inlier coordinates in the (possibly warped) images;
        * otherwise ``(orgSrc, orgDst, matchImg, image3)``, where the (2, N)
          coordinates are mapped back to the original images by
          ``orgKeypoints`` (a missing ``HFile2`` is treated as the identity)
          and ``matchImg`` is drawn by ``drawOrg``.
    """
    rgbFile1, rgbFile2 = img1, img2
    # Load each homography on its own so HFile2 can be used without HFile1.
    H1 = np.load(HFile1) if HFile1 is not None else None
    H2 = np.load(HFile2) if HFile2 is not None else None

    def _load_rgb(path, H):
        # Read an image as an RGB numpy array, warping it when H is given.
        im = Image.open(path)
        if im.mode != 'RGB':
            im = im.convert('RGB')
        im = np.array(im)
        if H is not None:
            im = cv2.warpPerspective(im, H, dsize=(400, 400))
        return im

    img1 = _load_rgb(rgbFile1, H1)
    img2 = _load_rgb(rgbFile2, H2)

    # SIFT is in the main cv2 namespace from OpenCV 4.4; older builds only
    # expose it through the contrib module.
    sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
    sift = sift_create()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)
    # detectAndCompute yields None descriptors when no keypoints are found.
    if des1 is None or des2 is None:
        return [], [], None, None

    matches = mnn_matcher(
        torch.from_numpy(des1).float().to(device=device),
        torch.from_numpy(des2).float().to(device=device)
    )
    src_pts = np.float32([kp1[i].pt for i, _ in matches]).reshape(-1, 2)
    dst_pts = np.float32([kp2[j].pt for _, j in matches]).reshape(-1, 2)
    if src_pts.shape[0] < 5 or dst_pts.shape[0] < 5:
        return [], [], None, None

    _, inliers = pydegensac.findHomography(src_pts, dst_pts, 8.0, 0.99, 10000)
    src_pts = src_pts[inliers]
    dst_pts = dst_pts[inliers]

    # Cast to Python float to be safe with stricter OpenCV argument parsing.
    inlier_keypoints_left = [cv2.KeyPoint(float(x), float(y), 1) for x, y in src_pts]
    inlier_keypoints_right = [cv2.KeyPoint(float(x), float(y), 1) for x, y in dst_pts]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(src_pts))]
    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None)
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

    if HFile1 is None:
        return src_pts, dst_pts, image3, image3

    # An image without a homography was not warped, so identity maps it back.
    if H2 is None:
        H2 = np.eye(3)
    orgSrc, orgDst = orgKeypoints(src_pts, dst_pts, H1, H2)
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)
    return orgSrc, orgDst, matchImg, image3
def orgKeypoints(src_pts, dst_pts, H1, H2):
    """Map matched points from warped views back into the original images.

    Args:
        src_pts, dst_pts: (N, 2) arrays of x/y coordinates in the warped images.
        H1, H2: 3x3 homographies that produced the warped images; their
            inverses are applied to ``src_pts`` and ``dst_pts`` respectively.

    Returns:
        (orgSrc, orgDst): arrays of shape (2, N). Row 0 holds x and row 1
        holds y, after dividing out the homogeneous coordinate.
    """
    # One column of ones, sized from src_pts, makes both sets homogeneous.
    ones_col = np.ones((src_pts.shape[0], 1))
    homogeneous = (np.hstack((src_pts, ones_col)), np.hstack((dst_pts, ones_col)))
    inverses = (np.linalg.inv(H1), np.linalg.inv(H2))

    unwarped = []
    for pts, h_inv in zip(homogeneous, inverses):
        projected = h_inv @ pts.T
        projected = projected / projected[2, :]
        unwarped.append(np.asarray(projected)[0:2, :])
    return unwarped[0], unwarped[1]
def drawOrg(image1, image2, orgSrc, orgDst):
    """Draw matched points on two images and join each pair with a line.

    Args:
        image1, image2: images as loaded by ``cv2.imread``. They are placed
            side by side with ``cv2.hconcat``, which requires the same number
            of rows and the same depth.
        orgSrc, orgDst: (2, N) arrays whose row 0 is x and row 1 is y, such
            as those returned by ``orgKeypoints``. Column ``i`` of each forms
            one match.

    Returns:
        The concatenated image with a circle at every point and a line per
        match. The channel order is swapped on entry and swapped back on exit.
        Works (returning the bare concatenation) when there are no points.
    """
    img1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
    # cv2.circle draws in place, so draw directly onto img1/img2. This keeps
    # both names defined even when there are no points to draw.
    for x, y in zip(orgSrc[0], orgSrc[1]):
        cv2.circle(img1, (int(x), int(y)), 3, (0, 0, 255), 1)
    for x, y in zip(orgDst[0], orgDst[1]):
        cv2.circle(img2, (int(x), int(y)), 3, (0, 0, 255), 1)

    canvas = cv2.hconcat([img1, img2])
    # Points of the second image sit to the right of the first one.
    x_offset = img1.shape[1]
    for xs, ys, xd, yd in zip(orgSrc[0], orgSrc[1], orgDst[0], orgDst[1]):
        cv2.line(canvas, (int(xs), int(ys)), (int(xd) + x_offset, int(yd)), (0, 255, 0), 1)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
def getPerspKeypoints(rgbFile1, rgbFile2, HFile1, HFile2, model, device):
    """Match learned features between two (optionally warped) images.

    Features come from ``extractSingle`` with ``model``. They are matched by
    mutual nearest neighbour (``mnn_matcher``) and filtered with pydegensac
    homography RANSAC.

    Args:
        rgbFile1, rgbFile2: image file paths.
        HFile1, HFile2: paths to ``.npy`` homographies (read with ``np.load``)
            applied via ``read_and_process_image``, or None to use the image
            unwarped. Each is handled independently.
        model: network passed to ``extractSingle``.
        device: torch device for feature extraction and matching.

    Returns:
        When ``HFile1`` is None: ``(pos_a, pos_b, image3, image3)``, where
        ``pos_a``/``pos_b`` are (N, 2) inlier coordinates in the (possibly
        warped) images and ``image3`` visualises those matches. Otherwise:
        ``(orgSrc, orgDst, matchImg, image3)``, where the (2, N) coordinates
        are mapped back by ``orgKeypoints`` (a missing ``HFile2`` is treated
        as the identity) and ``matchImg`` is drawn by ``drawOrg``. With fewer
        than 4 mutual matches, no RANSAC is run and zero inliers are kept.
    """
    # Load each homography on its own so HFile2 can be used without HFile1.
    H1 = np.load(HFile1) if HFile1 is not None else None
    H2 = np.load(HFile2) if HFile2 is not None else None
    igp1, img1 = read_and_process_image(rgbFile1, H=H1)
    igp2, img2 = read_and_process_image(rgbFile2, H=H2)

    feat1 = extractSingle(igp1, model, device)
    feat2 = extractSingle(igp2, model, device)
    matches = mnn_matcher(
        torch.from_numpy(feat1['descriptors']).to(device=device),
        torch.from_numpy(feat2['descriptors']).to(device=device),
    )
    pos_a = feat1["keypoints"][matches[:, 0], :2]
    pos_b = feat2["keypoints"][matches[:, 1], :2]

    # A homography needs at least 4 correspondences. With fewer, keep no
    # inliers instead of calling RANSAC.
    if len(pos_a) >= 4:
        _, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
    else:
        inliers = np.zeros(len(pos_a), dtype=bool)
    pos_a = pos_a[inliers]
    pos_b = pos_b[inliers]

    # Cast to Python float to be safe with stricter OpenCV argument parsing.
    inlier_keypoints_left = [cv2.KeyPoint(float(p[0]), float(p[1]), 1) for p in pos_a]
    inlier_keypoints_right = [cv2.KeyPoint(float(p[0]), float(p[1]), 1) for p in pos_b]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]
    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

    if HFile1 is None:
        return pos_a, pos_b, image3, image3

    # An image without a homography was not warped, so identity maps it back.
    if H2 is None:
        H2 = np.eye(3)
    orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
    # Reproject matches to the original (unwarped) views.
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)
    return orgSrc, orgDst, matchImg, image3
###### Ensemble
def read_and_process_image(img_path, resize=None, H=None, h=None, w=None, preprocessing='caffe'):
    """Load an image as RGB, optionally resize and warp it, and build the net input.

    Args:
        img_path: path opened with ``PIL.Image.open``.
        resize: if truthy, passed to ``PIL.Image.resize`` before any other step.
        H: optional homography. When given, the image is warped with
            ``cv2.warpPerspective`` into a fixed 400x400 output.
        h, w: unused; accepted for call compatibility.
        preprocessing: mode forwarded to ``preprocess_image``.

    Returns:
        ``(tensor, image)``: the float32 torch tensor built from
        ``preprocess_image`` and the RGB numpy image (after the optional warp)
        that it was computed from.
    """
    pil_img = Image.open(img_path)
    if resize:
        pil_img = pil_img.resize(resize)
    if pil_img.mode != 'RGB':
        pil_img = pil_img.convert('RGB')

    rgb = np.array(pil_img)
    if H is not None:
        rgb = cv2.warpPerspective(rgb, H, dsize=(400, 400))

    net_input = preprocess_image(rgb, preprocessing=preprocessing).astype(np.float32)
    return torch.from_numpy(net_input), rgb
def mnn_matcher_scorer(descriptors_a, descriptors_b, k=np.inf):
    """Mutual-nearest-neighbour matching that also returns each match's score.

    Similarity is the dot product ``descriptors_a @ descriptors_b.t()``. Row
    ``i`` of A and row ``j`` of B match when each is the other's
    highest-similarity partner.

    Args:
        descriptors_a, descriptors_b: descriptor tensors, one descriptor per row.
        k: unused; accepted for call compatibility.

    Returns:
        ``(matches, scores)``: a (K, 2) tensor of ``[index_in_a, index_in_b]``
        pairs ordered by ``index_in_a``, and the (K,) similarity of each pair.
        Both stay on the device of ``descriptors_a``.
    """
    scores = descriptors_a @ descriptors_b.t()
    best_score_ab, best_in_b = torch.max(scores, dim=1)
    best_in_a = torch.max(scores, dim=0)[1]
    row_ids = torch.arange(0, scores.shape[0], device=descriptors_a.device)
    is_mutual = best_in_a[best_in_b] == row_ids
    pairs = torch.stack([row_ids[is_mutual], best_in_b[is_mutual]]).t()
    return pairs, best_score_ab[is_mutual]
def mnn_matcher(descriptors_a, descriptors_b):
    """Return mutual-nearest-neighbour matches between two descriptor sets.

    Similarity is the dot product ``descriptors_a @ descriptors_b.t()``. Row
    ``i`` of A and row ``j`` of B match when each is the other's
    highest-similarity partner.

    Returns:
        numpy integer array of shape (K, 2) holding ``[index_in_a, index_in_b]``
        pairs, ordered by ``index_in_a``.
    """
    scores = descriptors_a @ descriptors_b.t()
    best_in_b = torch.max(scores, dim=1)[1]
    best_in_a = torch.max(scores, dim=0)[1]
    row_ids = torch.arange(0, scores.shape[0], device=descriptors_a.device)
    is_mutual = best_in_a[best_in_b] == row_ids
    pairs = torch.stack([row_ids[is_mutual], best_in_b[is_mutual]], dim=1)
    return pairs.detach().cpu().numpy()
def getPerspKeypointsEnsemble(model1, model2, rgbFile1, rgbFile2, HFile1, HFile2, device):
    """Match two images by pooling mutual-NN matches from two models.

    Each model extracts features from both images (via ``extractSingle``).
    Per-model matches come from ``mnn_matcher_scorer``. The pooled matches
    are ranked by their similarity score and the top half is kept, then the
    result is filtered with pydegensac homography RANSAC.

    Args:
        model1, model2: networks passed to ``extractSingle``.
        rgbFile1, rgbFile2: image file paths.
        HFile1, HFile2: paths to ``.npy`` homographies (read with ``np.load``)
            applied via ``read_and_process_image``, or None to use the image
            unwarped. Each is handled independently.
        device: torch device for feature extraction and matching.

    Returns:
        When ``HFile1`` is None: ``(pos_a, pos_b, image3, image3)`` with (N, 2)
        inlier coordinates in the (possibly warped) images. Otherwise:
        ``(orgSrc, orgDst, matchImg, image3)``, where the (2, N) coordinates
        are mapped back by ``orgKeypoints`` (a missing ``HFile2`` is treated
        as the identity) and ``matchImg`` is drawn by ``drawOrg``. With fewer
        than 4 kept matches, no RANSAC is run and zero inliers are kept.
    """
    # Load each homography on its own. Previously a None file left H1/H2
    # undefined and crashed at orgKeypoints.
    H1 = np.load(HFile1) if HFile1 is not None else None
    H2 = np.load(HFile2) if HFile2 is not None else None
    igp1, img1 = read_and_process_image(rgbFile1, H=H1)
    igp2, img2 = read_and_process_image(rgbFile2, H=H2)

    feat_a1 = extractSingle(igp1, model1, device)
    feat_a2 = extractSingle(igp1, model2, device)
    feat_b1 = extractSingle(igp2, model1, device)
    feat_b2 = extractSingle(igp2, model2, device)

    # Calculate matches (and their similarity scores) for both models.
    matches1, dist_1 = mnn_matcher_scorer(
        torch.from_numpy(feat_a1['descriptors']).to(device=device),
        torch.from_numpy(feat_b1['descriptors']).to(device=device),
    )
    matches2, dist_2 = mnn_matcher_scorer(
        torch.from_numpy(feat_a2['descriptors']).to(device=device),
        torch.from_numpy(feat_b2['descriptors']).to(device=device),
    )

    # Keep the best-scoring half of the pooled matches from both models.
    full_dist = torch.cat([dist_1, dist_2])
    top_k_idx = torch.topk(full_dist, k=len(full_dist) // 2)[1]
    # Pooled indices below n1 refer to model1's matches; the rest are offset
    # into model2's matches. Boolean selection keeps the top-k order.
    n1 = len(dist_1)
    sel1 = top_k_idx[top_k_idx < n1]
    sel2 = top_k_idx[top_k_idx >= n1] - n1
    matches1 = matches1[sel1].detach().cpu().numpy()
    matches2 = matches2[sel2].detach().cpu().numpy()

    pos_a = np.concatenate([
        feat_a1['keypoints'][matches1[:, 0], :2],
        feat_a2['keypoints'][matches2[:, 0], :2],
    ], 0)
    pos_b = np.concatenate([
        feat_b1['keypoints'][matches1[:, 1], :2],
        feat_b2['keypoints'][matches2[:, 1], :2],
    ], 0)

    # A homography needs at least 4 correspondences. With fewer, keep no
    # inliers instead of calling RANSAC.
    if len(pos_a) >= 4:
        _, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
    else:
        inliers = np.zeros(len(pos_a), dtype=bool)
    pos_a = pos_a[inliers]
    pos_b = pos_b[inliers]

    # Cast to Python float to be safe with stricter OpenCV argument parsing.
    inlier_keypoints_left = [cv2.KeyPoint(float(p[0]), float(p[1]), 1) for p in pos_a]
    inlier_keypoints_right = [cv2.KeyPoint(float(p[0]), float(p[1]), 1) for p in pos_b]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]
    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

    # Same contract as getPerspKeypoints when no source homography is given.
    if HFile1 is None:
        return pos_a, pos_b, image3, image3

    # An image without a homography was not warped, so identity maps it back.
    if H2 is None:
        H2 = np.eye(3)
    orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)
    return orgSrc, orgDst, matchImg, image3
if __name__ == '__main__':
    # Command-line entry point:
    #   python <this_script> SRC_IMAGE TRG_IMAGE [SRC_H.npy [TRG_H.npy]]
    # Homography files are optional; a missing one means that image is used
    # unwarped.
    WEIGHTS = '../models/rord.pth'
    if len(argv) < 3:
        exit('usage: %s SRC_IMAGE TRG_IMAGE [SRC_H.npy [TRG_H.npy]]' % argv[0])
    srcR = argv[1]
    trgR = argv[2]
    srcH = argv[3] if len(argv) > 3 else None
    trgH = argv[4] if len(argv) > 4 else None

    # getPerspKeypoints needs a model instance and a real torch device, not a
    # weights path and the string 'gpu'.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    # NOTE(review): keyword names assume lib.model_test.D2Net follows the D2-Net
    # constructor D2Net(model_file=..., use_relu=..., use_cuda=...) -- confirm.
    model = D2Net(model_file=WEIGHTS, use_relu=True, use_cuda=use_cuda).to(device)

    # getPerspKeypoints returns four values.
    orgSrc, orgDst, matchImg, warpedMatchImg = getPerspKeypoints(srcR, trgR, srcH, trgH, model, device)
    # Points are (2, N) when reprojected through a homography, else (N, 2).
    numMatches = orgSrc.shape[1] if srcH is not None else orgSrc.shape[0]
    print('Found %d inlier matches' % numMatches)