import argparse |
import numpy as np |
import imageio |
import torch |
from tqdm import tqdm |
import time |
import scipy |
import scipy.io |
import scipy.misc |
import os |
import sys |
from lib.model_test import D2Net |
from lib.utils import preprocess_image |
from lib.pyramid import process_multiscale |
import cv2 |
import matplotlib.pyplot as plt |
from PIL import Image |
from skimage.feature import match_descriptors |
from skimage.measure import ransac |
from skimage.transform import ProjectiveTransform, AffineTransform |
import pydegensac |
def _build_parser():
    """Build the command-line interface for this matching demo.

    Positional ``imgs`` takes exactly two image paths. Options:
    ``--preprocessing`` (default ``'caffe'``), ``--model_file`` (default
    ``None``), ``--no-relu`` (clears ``use_relu``, default ``True``) and
    ``--sift`` (sets ``use_sift``, default ``False``).
    """
    cli = argparse.ArgumentParser(description='Feature extraction script')
    cli.add_argument('imgs', type=str, nargs=2)
    cli.add_argument('--preprocessing', type=str, default='caffe',
                     help='image preprocessing (caffe or torch)')
    cli.add_argument('--model_file', type=str,
                     help='path to the full model')
    cli.add_argument('--no-relu', dest='use_relu', action='store_false',
                     help='remove ReLU after the dense feature extraction module')
    cli.add_argument('--sift', dest='use_sift', action='store_true',
                     help='Show sift matching as well')
    # Explicit defaults for both boolean switches, registered in one call.
    cli.set_defaults(use_relu=True, use_sift=False)
    return cli


parser = _build_parser()
def extract(image, args, model, device):
    """Run the dense feature network on one image and return its features.

    Args:
        image: numpy image, either 2-D (grayscale) or H x W x C.
            Single-channel images are replicated to 3 channels. Channels
            beyond the third (e.g. an alpha channel from RGBA PNGs) are
            dropped, because ``preprocess_image`` presumably expects exactly
            3 channels (TODO confirm against lib.utils). Previously such
            inputs were passed through unchanged and failed there.
        args: parsed CLI namespace; only ``args.preprocessing`` is read.
        model: network passed to ``process_multiscale``.
        device: torch device the input tensor is created on.

    Returns:
        dict with keys ``'keypoints'``, ``'scores'`` and ``'descriptors'``
        as returned by ``process_multiscale`` at the single scale ``[1]``.
        The first two keypoint columns are swapped, presumably from
        (row, col) to (x, y); callers use ``[:, :2]`` as point coordinates.
    """
    # Normalise to an H x W x 3 array.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, -1)
    elif image.shape[2] > 3:
        image = image[:, :, :3]

    input_image = preprocess_image(
        image,
        preprocessing=args.preprocessing
    )
    with torch.no_grad():
        keypoints, scores, descriptors = process_multiscale(
            torch.tensor(
                # Add a leading batch dimension of size 1.
                input_image[np.newaxis, :, :, :].astype(np.float32),
                device=device
            ),
            model,
            scales=[1]
        )
    # Swap the first two columns, keep the third as-is.
    keypoints = keypoints[:, [1, 0, 2]]
    return {
        'keypoints': keypoints,
        'scores': scores,
        'descriptors': descriptors,
    }
def rordMatching(image1, image2, feat1, feat2, matcher="BF"):
    """Match RoRD features between two images and plot the RANSAC inliers.

    Descriptors are matched with a cross-checked brute-force L2 matcher. The
    matched keypoint coordinates are filtered with pydegensac homography
    estimation, and the surviving correspondences are drawn side by side
    with matplotlib.

    Args:
        image1, image2: images passed to ``cv2.drawMatches`` for display.
        feat1, feat2: feature dicts as produced by ``extract``. Only
            ``'descriptors'`` and the first two columns of ``'keypoints'``
            (used as point coordinates) are read.
        matcher: matching strategy; only ``"BF"`` is implemented.

    Returns:
        None. Results are printed and plotted.

    Raises:
        ValueError: if ``matcher`` is not ``"BF"``. Previously any other
            value fell through to an ``UnboundLocalError`` on ``matches``.
    """
    if matcher != "BF":
        raise ValueError(
            "Unsupported matcher %r; only 'BF' is implemented." % (matcher,)
        )

    t0 = time.time()
    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = bf.match(feat1['descriptors'], feat2['descriptors'])
    matches = sorted(matches, key=lambda m: m.distance)
    t1 = time.time()
    print("Time to extract matches: ", t1 - t0)
    print("Number of raw matches:", len(matches))

    # A homography needs at least 4 point correspondences.
    if len(matches) < 4:
        print("Not enough matches to estimate a homography.")
        return

    match1 = [m.queryIdx for m in matches]
    match2 = [m.trainIdx for m in matches]
    keypoints_left = feat1['keypoints'][match1, :2]
    keypoints_right = feat2['keypoints'][match2, :2]

    # Seeds NumPy's global RNG; whether pydegensac draws from it is not
    # visible here (NOTE(review): confirm this actually makes RANSAC repeatable).
    np.random.seed(0)
    t0 = time.time()
    H, inliers = pydegensac.findHomography(keypoints_left, keypoints_right, 10.0, 0.99, 10000)
    t1 = time.time()
    print("Time for ransac: ", t1 - t0)

    # Flatten to a 1-D boolean mask so it selects rows instead of being
    # interpreted as integer indices.
    inliers = np.asarray(inliers, dtype=bool).ravel()
    n_inliers = int(np.sum(inliers))
    print('Number of inliers: %d.' % n_inliers)

    # cv2.KeyPoint is given plain Python floats rather than numpy scalars.
    inlier_keypoints_left = [cv2.KeyPoint(float(x), float(y), 1) for x, y in keypoints_left[inliers]]
    inlier_keypoints_right = [cv2.KeyPoint(float(x), float(y), 1) for x, y in keypoints_right[inliers]]
    # After filtering, the i-th left keypoint corresponds to the i-th right one.
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]
    draw_params = dict(matchColor=(0, 255, 0),
                       singlePointColor=(255, 0, 0),
                       flags=0)
    image3 = cv2.drawMatches(image1, inlier_keypoints_left, image2, inlier_keypoints_right, placeholder_matches, None, **draw_params)
    plt.figure(figsize=(20, 20))
    plt.imshow(image3)
    plt.axis('off')
    plt.show()
def siftMatching(img1, img2):
    """Match two images with SIFT + FLANN, filter with pydegensac, show inliers.

    Args:
        img1, img2: images convertible with ``np.array``. 3/4-channel inputs
            are treated as RGB and swapped to OpenCV's BGR order for
            display. 2-D inputs are expanded from grayscale.

    Returns:
        (src_pts, dst_pts): float32 arrays of shape (K, 2) holding the
        coordinates of the K inlier correspondences in ``img1`` and ``img2``.
        Both are empty (shape (0, 2)) when too few features or matches exist
        to estimate a homography.
    """
    # FLANN's KD-tree index id. This name was previously undefined, so
    # every call raised NameError.
    FLANN_INDEX_KDTREE = 1

    def to_bgr(img):
        # Converts one input to a 3-channel BGR array for OpenCV.
        arr = np.array(img)
        if arr.ndim == 2:
            return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

    def no_result():
        # Returns empty (0, 2) float32 point arrays.
        return np.empty((0, 2), dtype=np.float32), np.empty((0, 2), dtype=np.float32)

    img1 = to_bgr(img1)
    img2 = to_bgr(img2)

    # SIFT lives in the main cv2 module since OpenCV 4.4; older builds need contrib.
    if hasattr(cv2, 'SIFT_create'):
        sift = cv2.SIFT_create()
    else:
        sift = cv2.xfeatures2d.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)
    if des1 is None or des2 is None:
        print('No SIFT features found in at least one image.')
        return no_result()

    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(des1, des2, k=2)

    # Lowe's ratio test. knnMatch may return fewer than two neighbours for a
    # query, so incomplete pairs are skipped instead of failing to unpack.
    good = [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance]
    if len(good) < 4:
        print('Not enough SIFT matches to estimate a homography.')
        return no_result()

    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 2)
    _, inliers = pydegensac.findHomography(src_pts, dst_pts, 10.0, 0.99, 10000)
    inliers = np.asarray(inliers, dtype=bool).ravel()
    n_inliers = int(np.sum(inliers))
    print('Number of inliers: %d.' % n_inliers)

    inlier_src = src_pts[inliers]
    inlier_dst = dst_pts[inliers]
    inlier_keypoints_left = [cv2.KeyPoint(float(x), float(y), 1) for x, y in inlier_src]
    inlier_keypoints_right = [cv2.KeyPoint(float(x), float(y), 1) for x, y in inlier_dst]
    # After filtering, the i-th left keypoint corresponds to the i-th right one.
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]
    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None)
    cv2.imshow('Matches', image3)
    cv2.waitKey(0)
    cv2.destroyWindow('Matches')
    return inlier_src, inlier_dst
if __name__ == '__main__':
    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model = D2Net(
        model_file=args.model_file,
        use_relu=args.use_relu,
        use_cuda=use_cuda
    )

    # Load both images as 3-channel RGB. Otherwise palette ('P') PNGs become
    # 2-D index arrays and RGBA images carry an extra alpha channel. The
    # context manager closes each file handle after decoding.
    loaded = []
    for path in args.imgs:
        with Image.open(path) as pil_image:
            loaded.append(np.array(pil_image.convert('RGB')))
    image1, image2 = loaded

    print('--\nRoRD\n--')
    feat1 = extract(image1, args, model, device)
    feat2 = extract(image2, args, model, device)
    print("Features extracted.")
    rordMatching(image1, image2, feat1, feat2, matcher="BF")

    if args.use_sift:
        print('--\nSIFT\n--')
        siftMatching(image1, image2)