import numpy as np
import torch
import cv2
import pydegensac
from PIL import Image
from sys import argv

from lib.model_test import D2Net
from lib.utils import preprocess_image
from lib.pyramid import process_multiscale
def extractSingle(image, model, device):
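    """Run the D2-Net/RoRD detector on one preprocessed image at a single scale
    and return a dict with 'keypoints' ((x, y, scale) rows), 'scores', and
    'descriptors'."""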
    with torch.no_grad():
        keypoints, scores, descriptors = process_multiscale(
            image.to(device).unsqueeze(0),
            model,
            scales=[1]
        )
    keypoints = keypoints[:, [1, 0, 2]]  # (y, x, scale) -> (x, y, scale)

    feat = {}
    feat['keypoints'] = keypoints
    feat['scores'] = scores
    feat['descriptors'] = descriptors

    return feat
def siftMatching(img1, img2, HFile1, HFile2, device):
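    """SIFT baseline: match two images with mutual nearest neighbours and
    DEGENSAC. If homography files are given, both images are first warped to a
    400x400 top-down view and the inlier matches are reprojected back to the
    original perspective views; returns (src, dst, matchImg, image3), or the
    warped-view points twice with image3 duplicated when no homographies are
    supplied."""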
    if HFile1 is not None:
        H1 = np.load(HFile1)
        H2 = np.load(HFile2)

    rgbFile1 = img1
    img1 = Image.open(img1)
    if img1.mode != 'RGB':
        img1 = img1.convert('RGB')
    img1 = np.array(img1)
    if HFile1 is not None:
        img1 = cv2.warpPerspective(img1, H1, dsize=(400, 400))

    #### Visualization ####
    # cv2.imshow("Image", cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
    # cv2.waitKey(0)

    rgbFile2 = img2
    img2 = Image.open(img2)
    if img2.mode != 'RGB':
        img2 = img2.convert('RGB')
    img2 = np.array(img2)
    if HFile2 is not None:
        img2 = cv2.warpPerspective(img2, H2, dsize=(400, 400))

    #### Visualization ####
    # cv2.imshow("Image", cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))
    # cv2.waitKey(0)
    # sift = cv2.xfeatures2d.SURF_create(100)  # SURF alternative
    sift = cv2.xfeatures2d.SIFT_create()  # on OpenCV >= 4.4 use cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)
    matches = mnn_matcher(
        torch.from_numpy(des1).float().to(device=device),
        torch.from_numpy(des2).float().to(device=device)
    )

    src_pts = np.float32([kp1[m[0]].pt for m in matches]).reshape(-1, 2)
    dst_pts = np.float32([kp2[m[1]].pt for m in matches]).reshape(-1, 2)
    if src_pts.shape[0] < 5 or dst_pts.shape[0] < 5:
        return [], [], None, None  # too few matches; keep the 4-tuple return shape
    H, inliers = pydegensac.findHomography(src_pts, dst_pts, 8.0, 0.99, 10000)
    n_inliers = np.sum(inliers)

    inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in src_pts[inliers]]
    inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in dst_pts[inliers]]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]

    #### Visualization ####
    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None)
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)
    # cv2.imshow('Matches', image3)
    # cv2.waitKey()

    src_pts = np.float32([inlier_keypoints_left[m.queryIdx].pt for m in placeholder_matches]).reshape(-1, 2)
    dst_pts = np.float32([inlier_keypoints_right[m.trainIdx].pt for m in placeholder_matches]).reshape(-1, 2)
    if HFile1 is None:
        return src_pts, dst_pts, image3, image3

    orgSrc, orgDst = orgKeypoints(src_pts, dst_pts, H1, H2)
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3
def orgKeypoints(src_pts, dst_pts, H1, H2):
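    """Map matched keypoints from the warped (top-down) views back to the
    original perspective views via the inverse homographies; returns two
    2xN arrays of pixel coordinates."""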
    ones = np.ones((src_pts.shape[0], 1))
    src_pts = np.hstack((src_pts, ones))
    dst_pts = np.hstack((dst_pts, ones))

    orgSrc = np.linalg.inv(H1) @ src_pts.T
    orgDst = np.linalg.inv(H2) @ dst_pts.T

    orgSrc = orgSrc / orgSrc[2, :]
    orgDst = orgDst / orgDst[2, :]

    orgSrc = np.asarray(orgSrc)[0:2, :]
    orgDst = np.asarray(orgDst)[0:2, :]

    return orgSrc, orgDst
def drawOrg(image1, image2, orgSrc, orgDst):
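    """Visualize reprojected correspondences: circles at the keypoints in each
    image and green lines joining them across the side-by-side concatenation."""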
    img1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)

    for i in range(orgSrc.shape[1]):
        img1 = cv2.circle(img1, (int(orgSrc[0, i]), int(orgSrc[1, i])), 3, (0, 0, 255), 1)
    for i in range(orgDst.shape[1]):
        img2 = cv2.circle(img2, (int(orgDst[0, i]), int(orgDst[1, i])), 3, (0, 0, 255), 1)

    im4 = cv2.hconcat([img1, img2])
    for i in range(orgSrc.shape[1]):
        im4 = cv2.line(im4, (int(orgSrc[0, i]), int(orgSrc[1, i])), (int(orgDst[0, i]) + img1.shape[1], int(orgDst[1, i])), (0, 255, 0), 1)

    im4 = cv2.cvtColor(im4, cv2.COLOR_BGR2RGB)
    # cv2.imshow("Image", im4)
    # cv2.waitKey(0)
    return im4
def getPerspKeypoints(rgbFile1, rgbFile2, HFile1, HFile2, model, device):
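    """Match two images with a single D2-Net/RoRD model, filtering with
    DEGENSAC. With per-image homographies, matching happens in the warped
    top-down views and the inliers are reprojected back to the perspective
    views; without them, the raw matched points are returned directly."""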
    if HFile1 is None:
        igp1, img1 = read_and_process_image(rgbFile1, H=None)
    else:
        H1 = np.load(HFile1)
        igp1, img1 = read_and_process_image(rgbFile1, H=H1)

    c, h, w = igp1.shape

    if HFile2 is None:
        igp2, img2 = read_and_process_image(rgbFile2, H=None)
    else:
        H2 = np.load(HFile2)
        igp2, img2 = read_and_process_image(rgbFile2, H=H2)
    feat1 = extractSingle(igp1, model, device)
    feat2 = extractSingle(igp2, model, device)

    matches = mnn_matcher(
        torch.from_numpy(feat1['descriptors']).to(device=device),
        torch.from_numpy(feat2['descriptors']).to(device=device),
    )

    pos_a = feat1['keypoints'][matches[:, 0], :2]
    pos_b = feat2['keypoints'][matches[:, 1], :2]

    H, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
    pos_a = pos_a[inliers]
    pos_b = pos_b[inliers]

    inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_a]
    inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_b]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]

    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

    #### Visualization ####
    # cv2.imshow('Matches', image3)
    # cv2.waitKey()

    if HFile1 is None:
        return pos_a, pos_b, image3, image3

    orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)  # reproject matches to the perspective views

    return orgSrc, orgDst, matchImg, image3
###### Ensemble ######
def read_and_process_image(img_path, resize=None, H=None, h=None, w=None, preprocessing='caffe'):
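    """Load an image as RGB, optionally resize it and warp it with H to a
    400x400 top-down view, and return the preprocessed torch tensor together
    with the numpy image."""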
    img1 = Image.open(img_path)
    if resize:
        img1 = img1.resize(resize)
    if img1.mode != 'RGB':
        img1 = img1.convert('RGB')
    img1 = np.array(img1)

    if H is not None:
        img1 = cv2.warpPerspective(img1, H, dsize=(400, 400))
        # cv2.imshow("Image", cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
        # cv2.waitKey(0)

    igp1 = torch.from_numpy(preprocess_image(img1, preprocessing=preprocessing).astype(np.float32))

    return igp1, img1
def mnn_matcher_scorer(descriptors_a, descriptors_b, k=np.inf):
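    """Mutual nearest-neighbour matching under dot-product similarity; also
    returns each match's similarity score, used to rank ensemble matches.
    (The k parameter is currently unused.)"""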
    device = descriptors_a.device
    sim = descriptors_a @ descriptors_b.t()
    val1, nn12 = torch.max(sim, dim=1)
    val2, nn21 = torch.max(sim, dim=0)
    ids1 = torch.arange(0, sim.shape[0], device=device)
    mask = (ids1 == nn21[nn12])
    matches = torch.stack([ids1[mask], nn12[mask]]).t()
    remaining_matches_dist = val1[mask]
    return matches, remaining_matches_dist
def mnn_matcher(descriptors_a, descriptors_b):
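    """Mutual nearest-neighbour matching under dot-product similarity; returns
    an Nx2 numpy array of (index-in-a, index-in-b) pairs."""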
    device = descriptors_a.device
    sim = descriptors_a @ descriptors_b.t()
    nn12 = torch.max(sim, dim=1)[1]
    nn21 = torch.max(sim, dim=0)[1]
    ids1 = torch.arange(0, sim.shape[0], device=device)
    mask = (ids1 == nn21[nn12])
    matches = torch.stack([ids1[mask], nn12[mask]])
    return matches.t().data.cpu().numpy()
def getPerspKeypointsEnsemble(model1, model2, rgbFile1, rgbFile2, HFile1, HFile2, device):
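    """Match two images with an ensemble of two models: pool the mutual-NN
    matches from both, keep the top half ranked by descriptor similarity,
    then filter with DEGENSAC and reproject the inliers to the perspective
    views."""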
    if HFile1 is None:
        igp1, img1 = read_and_process_image(rgbFile1, H=None)
    else:
        H1 = np.load(HFile1)
        igp1, img1 = read_and_process_image(rgbFile1, H=H1)

    c, h, w = igp1.shape

    if HFile2 is None:
        igp2, img2 = read_and_process_image(rgbFile2, H=None)
    else:
        H2 = np.load(HFile2)
        igp2, img2 = read_and_process_image(rgbFile2, H=H2)

    with torch.no_grad():
        keypoints_a1, scores_a1, descriptors_a1 = process_multiscale(
            igp1.to(device).unsqueeze(0),
            model1,
            scales=[1]
        )
        keypoints_a1 = keypoints_a1[:, [1, 0, 2]]

        keypoints_a2, scores_a2, descriptors_a2 = process_multiscale(
            igp1.to(device).unsqueeze(0),
            model2,
            scales=[1]
        )
        keypoints_a2 = keypoints_a2[:, [1, 0, 2]]

        keypoints_b1, scores_b1, descriptors_b1 = process_multiscale(
            igp2.to(device).unsqueeze(0),
            model1,
            scales=[1]
        )
        keypoints_b1 = keypoints_b1[:, [1, 0, 2]]

        keypoints_b2, scores_b2, descriptors_b2 = process_multiscale(
            igp2.to(device).unsqueeze(0),
            model2,
            scales=[1]
        )
        keypoints_b2 = keypoints_b2[:, [1, 0, 2]]
    # calculate mutual-NN matches (with similarity scores) for both models
    matches1, dist_1 = mnn_matcher_scorer(
        torch.from_numpy(descriptors_a1).to(device=device),
        torch.from_numpy(descriptors_b1).to(device=device),
    )
    matches2, dist_2 = mnn_matcher_scorer(
        torch.from_numpy(descriptors_a2).to(device=device),
        torch.from_numpy(descriptors_b2).to(device=device),
    )

    full_matches = torch.cat([matches1, matches2])
    full_dist = torch.cat([dist_1, dist_2])
    assert len(full_dist) == (len(dist_1) + len(dist_2)), "pooled score count does not match the two models' match counts"

    # keep the top half of the pooled matches, ranked by descriptor similarity
    k_final = len(full_dist) // 2
    # k_final = len(full_dist)
    # k_final = max(len(dist_1), len(dist_2))
    top_k_mask = torch.topk(full_dist, k=k_final)[1]

    # split the surviving indices back into per-model match lists
    first = []
    second = []
    for valid_id in top_k_mask:
        if valid_id < len(dist_1):
            first.append(valid_id)
        else:
            second.append(valid_id - len(dist_1))

    matches1 = matches1[torch.tensor(first, device=device).long()].data.cpu().numpy()
    matches2 = matches2[torch.tensor(second, device=device).long()].data.cpu().numpy()
    pos_a1 = keypoints_a1[matches1[:, 0], :2]
    pos_b1 = keypoints_b1[matches1[:, 1], :2]
    pos_a2 = keypoints_a2[matches2[:, 0], :2]
    pos_b2 = keypoints_b2[matches2[:, 1], :2]

    pos_a = np.concatenate([pos_a1, pos_a2], 0)
    pos_b = np.concatenate([pos_b1, pos_b2], 0)

    # pos_a, pos_b, inliers = apply_ransac(pos_a, pos_b)
    H, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
    pos_a = pos_a[inliers]
    pos_b = pos_b[inliers]

    inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_a]
    inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_b]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]

    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)
    # cv2.imshow('Matches', image3)
    # cv2.waitKey()
    if HFile1 is None:
        # no homographies were given, so H1/H2 do not exist; return the matched
        # points directly, mirroring getPerspKeypoints
        return pos_a, pos_b, image3, image3

    orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3
if __name__ == '__main__':
    WEIGHTS = '../models/rord.pth'

    srcR = argv[1]
    trgR = argv[2]
    srcH = argv[3]
    trgH = argv[4]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load the RoRD weights into a D2-Net model (constructor signature per lib.model_test)
    model = D2Net(model_file=WEIGHTS).to(device)

    orgSrc, orgDst, matchImg, image3 = getPerspKeypoints(srcR, trgR, srcH, trgH, model, device)
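    # Hypothetical invocation (file names assumed, not from the source):
    #   python match.py src.jpg trg.jpg srcH.npy trgH.npy
    # where the .npy files hold the 3x3 homographies that warp each image to a
    # 400x400 top-down view.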