Spaces:
Sleeping
Sleeping
# Copyright 2020 Toyota Research Institute. All rights reserved.
import numpy as np | |
import torch | |
import torchvision.transforms as transforms | |
from tqdm import tqdm | |
from evaluation.descriptor_evaluation import compute_homography, compute_matching_score | |
from evaluation.detector_evaluation import compute_repeatability | |
def _keypoints_and_descriptors(score, coord, desc, conf_threshold):
    """Flatten one image's network outputs into filtered keypoint and descriptor rows.

    Parameters
    ----------
    score: torch.Tensor
        Per-cell keypoint scores for ONE image; assumed shape (1, Hc, Wc) — TODO confirm.
    coord: torch.Tensor
        Per-cell keypoint coordinates for ONE image; assumed shape (2, Hc, Wc) — TODO confirm.
    desc: torch.Tensor
        Per-cell descriptors for ONE image, shape (D, Hc, Wc).
    conf_threshold: float
        Only cells whose score is strictly greater than this value are kept.

    Returns
    -------
    points: np.ndarray
        (N, 3) array. Columns 0-1 are the two coordinate channels (in the
        network's channel order) and column 2 is the score.
    descriptors: np.ndarray
        (N, D) array of descriptors aligned row-for-row with ``points``.
    """
    # Stack coordinates and score channel-wise, then turn every cell into one row.
    points = torch.cat([coord, score], dim=0).reshape(3, -1).t().cpu().numpy()
    # Read the descriptor dimension from the tensor instead of hard-coding 256.
    descriptors = desc.reshape(desc.shape[0], -1).t().cpu().numpy()
    # Compute the confidence mask once and apply it to both arrays so rows stay aligned.
    keep = points[:, 2] > conf_threshold
    return points[keep, :], descriptors[keep, :]


def evaluate_keypoint_net(
    data_loader, keypoint_net, output_shape=(320, 240), top_k=300, device="cuda"
):
    """Keypoint net evaluation script.

    Runs ``keypoint_net`` on every (image, warped_image) pair produced by
    ``data_loader``. It then averages repeatability, localization error,
    homography correctness and matching score over all evaluated images.

    Parameters
    ----------
    data_loader: torch.utils.data.DataLoader
        Dataset loader. Each sample must provide "image", "warped_image" and
        "homography" tensors with a leading batch dimension.
    keypoint_net: torch.nn.Module
        Keypoint network returning ``(score, coord, desc)`` for an image batch.
        It is switched to eval mode and left in eval mode on return.
    output_shape: tuple
        Original image shape. It is reversed before being passed as
        ``image_shape`` (presumably given as (width, height) — TODO confirm).
    top_k: int
        Number of keypoints used to compute metrics, selected based on
        probability (forwarded as ``keep_k_points``).
    device: str or torch.device
        Device the input images are moved to. The default ``"cuda"`` matches
        the previous hard-coded ``.cuda()`` behaviour.

    Returns
    -------
    tuple of float
        In order:
        - mean repeatability;
        - mean localization error;
        - the means of the three values returned by ``compute_homography``
          (stored as correctness1/3/5);
        - the mean matching score.
        Each value is ``nan`` if no images were evaluated.
    """
    keypoint_net.eval()
    # eval() already clears `training`. This assignment is kept so the attribute
    # is set even if the network overrides train()/eval().
    keypoint_net.training = False
    conf_threshold = 0.0
    localization_err, repeatability = [], []
    correctness1, correctness3, correctness5, MScore = [], [], [], []
    with torch.no_grad():
        # Iterate the loader directly (not via enumerate) so tqdm can show a total.
        for sample in tqdm(data_loader, desc="Evaluate point model"):
            score_1, coord_1, desc_1 = keypoint_net(sample["image"].to(device))
            score_2, coord_2, desc_2 = keypoint_net(
                sample["warped_image"].to(device)
            )

            # Evaluate each image of the batch on its own. Flattening the whole
            # batch at once only produced correct rows for batch size 1.
            for b in range(desc_1.shape[0]):
                points_1, descriptors_1 = _keypoints_and_descriptors(
                    score_1[b], coord_1[b], desc_1[b], conf_threshold
                )
                points_2, descriptors_2 = _keypoints_and_descriptors(
                    score_2[b], coord_2[b], desc_2[b], conf_threshold
                )

                # Prepare data for eval
                data = {
                    "image": sample["image"][b].numpy().squeeze(),
                    "image_shape": output_shape[::-1],
                    "warped_image": sample["warped_image"][b].numpy().squeeze(),
                    "homography": sample["homography"][b].squeeze().numpy(),
                    "prob": points_1,
                    "warped_prob": points_2,
                    "desc": descriptors_1,
                    "warped_desc": descriptors_2,
                }

                # Compute repeatability and localization error
                _, _, rep, loc_err = compute_repeatability(
                    data, keep_k_points=top_k, distance_thresh=3
                )
                repeatability.append(rep)
                localization_err.append(loc_err)

                # Compute correctness
                c1, c2, c3 = compute_homography(data, keep_k_points=top_k)
                correctness1.append(c1)
                correctness3.append(c2)
                correctness5.append(c3)

                # Compute matching score
                MScore.append(compute_matching_score(data, keep_k_points=top_k))

    return (
        np.mean(repeatability),
        np.mean(localization_err),
        np.mean(correctness1),
        np.mean(correctness3),
        np.mean(correctness5),
        np.mean(MScore),
    )