Spaces:
Sleeping
Sleeping
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import cv2 | |
from sys import exit | |
import torch | |
import torch.nn.functional as F | |
from lib.utils import ( | |
grid_positions, | |
upscale_positions, | |
downscale_positions, | |
savefig, | |
imshow_image | |
) | |
from lib.exceptions import NoGradientError, EmptyTensorError | |
# Use the non-interactive Agg backend so figures can be rendered to files
# on machines without a display.
_MPL_BACKEND = 'Agg'
matplotlib.use(_MPL_BACKEND)
def _hardest_negative_distance(descriptors, fmap_pos, all_descriptors, all_fmap_pos, safe_radius):
    """Return, per column of ``descriptors``, the smallest descriptor distance to
    any location of ``all_descriptors`` lying outside ``safe_radius``.

    ``fmap_pos`` and ``all_fmap_pos`` hold feature-map positions with rows 0
    and 1 as (row, col). A location in ``all_fmap_pos`` is an admissible
    negative for column i when its Chebyshev (max-abs) distance to
    ``fmap_pos[:, i]`` exceeds ``safe_radius``.
    """
    position_distance = torch.max(
        torch.abs(fmap_pos.unsqueeze(2).float() - all_fmap_pos.unsqueeze(1)),
        dim=0
    )[0]
    is_out_of_safe_radius = position_distance > safe_radius
    distance_matrix = 2 - 2 * (descriptors.t() @ all_descriptors)
    # For unit-norm descriptors distances lie in [0, 4], so adding 10 to the
    # locations inside the safe radius guarantees they never win the min.
    return torch.min(
        distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
        dim=1
    )[0]


def loss_function(
    model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False, plot_path=None
):
    """Compute a score-weighted triplet-margin ranking loss over a batch of image pairs.

    The model is run on ``batch['image1']`` / ``batch['image2']``. For every
    pair with at least 128 ground-truth correspondences (``batch['pos1']`` /
    ``batch['pos2']``), the positive descriptor distance at each
    correspondence is compared with the hardest negative found outside a
    ``safe_radius`` neighbourhood in either feature map. The hinge
    ``relu(margin + positive - hardest_negative)`` is averaged with weights
    ``scores1 * scores2``, and the per-pair losses are averaged over the
    valid pairs.

    Args:
        model: callable taking ``{'image1', 'image2'}`` tensors and returning a
            dict with ``dense_features1/2`` and ``scores1/2``.
        batch: dict with at least ``image1``, ``image2``, ``pos1``, ``pos2``
            (plus the keys read by ``drawTraining`` when ``plot`` is set).
        device: torch device the computation runs on.
        margin: hinge margin.
        safe_radius: Chebyshev radius, in feature-map cells, inside which
            locations are not used as negatives.
        scaling_steps: downscaling steps between image and feature-map
            coordinates; applied consistently to ``pos1`` and ``pos2``.
        plot: if True, dump visualisations every ``batch['log_interval']``
            batches.
        plot_path: output directory for the visualisations.

    Returns:
        1-element tensor holding the mean loss over the valid pairs.

    Raises:
        NoGradientError: if no pair in the batch had enough correspondences.
    """
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.zeros(1, device=device)
    n_valid_samples = 0

    for idx_in_batch in range(batch['image1'].size(0)):
        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        pos1 = batch['pos1'][idx_in_batch].to(device)
        pos2 = batch['pos2'][idx_in_batch].to(device)

        # Flat feature-map indices of the GT positions in image 1. scaling_steps
        # is forwarded so pos1 and pos2 are downscaled the same way.
        ids = idsAlign(pos1, device, h1, w1, scaling_steps=scaling_steps)

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        all_fmap_pos1 = grid_positions(h1, w1, device)
        fmap_pos1 = all_fmap_pos1[:, ids]
        descriptors1 = all_descriptors1[:, ids]
        scores1 = scores1[ids]

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)
        ).long()
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
            dim=0
        )

        # Squared L2 distance between unit vectors: ||a - b||^2 = 2 - 2 a.b
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()

        all_fmap_pos2 = grid_positions(h2, w2, device)
        # Hardest negative for each image-1 descriptor among image-2 locations
        # away from its true match, and symmetrically for image 2.
        negative_distance2 = _hardest_negative_distance(
            descriptors1, fmap_pos2, all_descriptors2, all_fmap_pos2, safe_radius
        )
        negative_distance1 = _hardest_negative_distance(
            descriptors2, fmap_pos1, all_descriptors1, all_fmap_pos1, safe_radius
        )

        diff = positive_distance - torch.min(
            negative_distance1, negative_distance2
        )

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (
            torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
            torch.sum(scores1 * scores2)
        )
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            drawTraining(
                batch['image1'], batch['image2'], pos1, pos2, batch,
                idx_in_batch, output, save=True, plot_path=plot_path
            )

    if n_valid_samples == 0:
        raise NoGradientError

    return loss / n_valid_samples


def idsAlign(pos1, device, h1, w1, scaling_steps=3):
    """Convert image-space positions into flat indices of an (h1, w1) feature map.

    The positions are brought to feature-map resolution with
    ``downscale_positions``, rounded to the nearest cell on each axis, and
    flattened row-major as ``row * w1 + col``. This matches how
    ``loss_function`` flattens descriptors with ``view(c, -1)``.

    Args:
        pos1: tensor whose rows 0 and 1 hold (row, col) image coordinates.
        device: device of the returned tensor.
        h1: feature-map height (unused; kept for interface compatibility).
        w1: feature-map width.
        scaling_steps: downscaling steps between image and feature map
            (default 3, the value previously hard-coded here).

    Returns:
        LongTensor of flat indices, one per column of ``pos1``, on ``device``.
    """
    fmap_pos = torch.round(
        downscale_positions(pos1, scaling_steps=scaling_steps)
    ).long()
    # Round each axis before flattening: rounding ``row * w1 + col`` on
    # fractional coordinates would leak the fractional row into the column.
    ids = fmap_pos[0, :] * w1 + fmap_pos[1, :]
    return ids.to(device)
def _viz_filename(plot_path, prefix, batch, idx_in_batch):
    """Return ``<plot_path>/<prefix>.<epoch>.<log step>.<idx>.png`` for a visualisation."""
    import os

    name = '%s.%02d.%02d.%d.png' % (
        prefix,
        batch['epoch_idx'],
        batch['batch_idx'] // batch['log_interval'],
        idx_in_batch
    )
    return os.path.join(plot_path, name)


def drawTraining(image1, image2, pos1, pos2, batch, idx_in_batch, output, save=False, plot_path="train_viz"):
    """Visualise ground-truth correspondences and score maps for one batch item.

    Produces two images:
      1. a matplotlib figure with four panels: image 1 with its positions,
         the score map of image 1, image 2 with its positions, and the score
         map of image 2;
      2. an OpenCV side-by-side image with every 5th correspondence drawn as
         circles joined by lines.

    Args:
        image1, image2: batched image tensors; item ``idx_in_batch`` is drawn.
        pos1, pos2: tensors whose rows 0 and 1 hold (row, col) coordinates;
            column i of ``pos1`` corresponds to column i of ``pos2``.
        batch: dict providing ``preprocessing`` and, when saving, ``train``,
            ``epoch_idx``, ``batch_idx`` and ``log_interval``.
        idx_in_batch: index of the batch item to draw.
        output: model output dict with ``scores1`` / ``scores2``.
        save: if True, write PNGs under ``plot_path``; otherwise show them.
        plot_path: output directory; ``None`` falls back to ``"train_viz"``.
    """
    if plot_path is None:
        # loss_function forwards plot_path=None by default; use the same
        # directory this function defaults to instead of crashing.
        plot_path = "train_viz"

    pos1_aux = pos1.cpu().numpy()
    pos2_aux = pos2.cpu().numpy()

    # One random colour per correspondence so matching points share a colour
    # in both scatter plots.
    col = np.random.rand(pos1_aux.shape[1], 3)
    n_sp = 4

    # --- Matplotlib overview: image 1, scores 1, image 2, scores 2 ---
    plt.figure()

    plt.subplot(1, n_sp, 1)
    im1 = imshow_image(
        image1[idx_in_batch].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im1)
    plt.scatter(
        pos1_aux[1, :], pos1_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')

    plt.subplot(1, n_sp, 2)
    plt.imshow(
        output['scores1'][idx_in_batch].detach().cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')

    plt.subplot(1, n_sp, 3)
    im2 = imshow_image(
        image2[idx_in_batch].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im2)
    plt.scatter(
        pos2_aux[1, :], pos2_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')

    plt.subplot(1, n_sp, 4)
    plt.imshow(
        output['scores2'][idx_in_batch].detach().cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')

    if save:
        prefix = 'train' if batch['train'] else 'valid'
        savefig(_viz_filename(plot_path, prefix, batch, idx_in_batch), dpi=300)
    else:
        plt.show()
    # Always release the figure so repeated calls do not accumulate open figures.
    plt.close()

    # --- OpenCV side-by-side view of every 5th correspondence ---
    # cvtColor swaps channels 0 and 2; presumably imshow_image returns RGB and
    # OpenCV expects BGR — TODO confirm.
    im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB)
    im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB)

    step = 5
    # OpenCV drawing functions need integer pixel coordinates in (x, y) order.
    for i in range(0, pos1_aux.shape[1], step):
        im1 = cv2.circle(im1, (int(pos1_aux[1, i]), int(pos1_aux[0, i])), 1, (0, 0, 255), 2)
    for i in range(0, pos2_aux.shape[1], step):
        im2 = cv2.circle(im2, (int(pos2_aux[1, i]), int(pos2_aux[0, i])), 1, (0, 0, 255), 2)

    im3 = cv2.hconcat([im1, im2])
    x_offset = im1.shape[1]  # image 2 starts right after image 1's width
    for i in range(0, pos1_aux.shape[1], step):
        im3 = cv2.line(
            im3,
            (int(pos1_aux[1, i]), int(pos1_aux[0, i])),
            (int(pos2_aux[1, i]) + x_offset, int(pos2_aux[0, i])),
            (0, 255, 0), 1
        )

    if save:
        # Distinct prefix so the validation correspondence image does not
        # overwrite the matplotlib figure saved above under 'valid'.
        prefix = 'train_corr' if batch['train'] else 'valid_corr'
        cv2.imwrite(_viz_filename(plot_path, prefix, batch, idx_in_batch), im3)
    else:
        cv2.imshow('Image', im3)
        cv2.waitKey(0)