# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Whitener and RetrievalModel
# --------------------------------------------------------
import numpy as np
from tqdm import tqdm
import time
import torch
import torch.nn as nn
import mast3r.utils.path_to_dust3r # noqa
from dust3r.utils.image import load_images
default_device = torch.device('cuda' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu')
# from https://github.com/gtolias/how/blob/4d73c88e0ffb55506e2ce6249e2a015ef6ccf79f/how/utils/whitening.py#L20
def pcawhitenlearn_shrinkage(X, s=1.0):
    """Learn PCA whitening with shrinkage from the given descriptors.

    X: (N, dim) array of descriptors.
    s: shrinkage exponent on the eigenvalues (s=1.0 gives full whitening).
    Returns (m, P) such that (X - m) @ P yields the whitened descriptors.
    """
N = X.shape[0]
# Learning PCA w/o annotations
m = X.mean(axis=0, keepdims=True)
Xc = X - m
Xcov = np.dot(Xc.T, Xc)
Xcov = (Xcov + Xcov.T) / (2 * N)
    eigval, eigvec = np.linalg.eigh(Xcov)  # Xcov is symmetric: eigh avoids spurious complex eigenvalues
order = eigval.argsort()[::-1]
eigval = eigval[order]
eigvec = eigvec[:, order]
eigval = np.clip(eigval, a_min=1e-14, a_max=None)
P = np.dot(np.linalg.inv(np.diag(np.power(eigval, 0.5 * s))), eigvec.T)
return m, P.T
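
# A minimal sanity-check sketch of the whitening: with s=1.0, descriptors
# projected with the learned (m, P) should have near-identity covariance.
# The sizes below (1000 descriptors of dimension 64) are arbitrary placeholders.
def _demo_pcawhitenlearn_shrinkage():
    X = np.random.randn(1000, 64)
    m, P = pcawhitenlearn_shrinkage(X, s=1.0)
    Xw = (X - m) @ P
    cov = Xw.T @ Xw / X.shape[0]
    assert np.allclose(cov, np.eye(64), atol=1e-6)
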
class Dust3rInputFromImageList(torch.utils.data.Dataset):
def __init__(self, image_list, imsize=512):
super().__init__()
self.image_list = image_list
        assert imsize == 512, 'only the 512 resolution preset is supported'
self.imsize = imsize
def __len__(self):
return len(self.image_list)
def __getitem__(self, index):
return load_images([self.image_list[index]], size=self.imsize, verbose=False)[0]
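
# A minimal usage sketch: each dataset item is a dust3r-style view dict with
# 'img' and 'true_shape' entries, so a batch_size-1 DataLoader with an identity
# collate function (see identity() further below) yields the dicts unchanged.
def _demo_image_list_loading(image_list):
    imdataset = Dust3rInputFromImageList(image_list, imsize=512)
    loader = torch.utils.data.DataLoader(imdataset, batch_size=1, collate_fn=identity)
    for batch in loader:
        view = batch[0]
        print(view['img'].shape, view['true_shape'])
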
class Whitener(nn.Module):
def __init__(self, dim, l2norm=None):
super().__init__()
self.m = torch.nn.Parameter(torch.zeros((1, dim)).double())
self.p = torch.nn.Parameter(torch.eye(dim, dim).double())
self.l2norm = l2norm # if not None, apply l2 norm along a given dimension
def forward(self, x):
with torch.autocast(self.m.device.type, enabled=False):
shape = x.size()
input_type = x.dtype
x_reshaped = x.view(-1, shape[-1]).to(dtype=self.m.dtype)
# Center the input data
x_centered = x_reshaped - self.m
# Apply PCA transformation
pca_output = torch.matmul(x_centered, self.p)
            # reshape back to the input shape
            pca_output = pca_output.view(shape)
if self.l2norm is not None:
return torch.nn.functional.normalize(pca_output, dim=self.l2norm).to(dtype=input_type)
return pca_output.to(dtype=input_type)
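
# A minimal sketch of how a learned whitening is loaded into a Whitener, as
# done in RetrievalModel.reinitialize_whitening() below; the dimension 64 and
# the random inputs are placeholders.
def _demo_whitener():
    X = np.random.randn(1000, 64)
    m, P = pcawhitenlearn_shrinkage(X)
    whitener = Whitener(64, l2norm=-1)  # l2-normalize along the last dimension
    whitener.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})
    out = whitener(torch.randn(2, 10, 64))  # same shape as the input, unit-norm rows
    assert out.shape == (2, 10, 64)
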
def weighted_spoc(feat, attn):
    """
    feat: BxNxC
    attn: BxN
    output: BxC, the L2-normalized weighted-sum pooling of the features
    """
return torch.nn.functional.normalize((feat * attn[:, :, None]).sum(dim=1), dim=1)
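
# A minimal sketch: weighted_spoc pools the N local features of each image into
# a single unit-norm global descriptor; shapes below are placeholders.
def _demo_weighted_spoc():
    feat = torch.randn(2, 100, 32)  # B=2 images, N=100 features, C=32 dims
    attn = torch.rand(2, 100)       # one non-negative weight per feature
    g = weighted_spoc(feat, attn)   # (2, 32)
    assert torch.allclose(g.norm(dim=1), torch.ones(2))
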
def how_select_local(feat, attn, nfeat):
    """
    feat: BxNxC
    attn: BxN
    nfeat: number of features to keep; a negative value in [-1, 0) keeps that fraction of N
    """
    # resolve nfeat
    if nfeat < 0:
        assert nfeat >= -1.0
        nfeat = int(-nfeat * feat.size(1))
    else:
        nfeat = int(nfeat)
    # sort by attention and keep the top-nfeat features
topk_attn, topk_indices = torch.topk(attn, min(nfeat, attn.size(1)), dim=1)
topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, feat.size(2))
topk_features = torch.gather(feat, 1, topk_indices_expanded)
return topk_features, topk_attn, topk_indices
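
# A minimal sketch: a negative nfeat keeps a fraction of the features,
# e.g. nfeat=-0.25 keeps the top quarter ranked by attention.
def _demo_how_select_local():
    feat = torch.randn(2, 100, 32)
    attn = torch.rand(2, 100)
    topk_feat, topk_attn, topk_idx = how_select_local(feat, attn, nfeat=-0.25)
    assert topk_feat.shape == (2, 25, 32) and topk_attn.shape == (2, 25)
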
class RetrievalModel(nn.Module):
def __init__(self, backbone, freeze_backbone=1, prewhiten=None, hdims=[1024], residual=False, postwhiten=None,
featweights='l2norm', nfeat=300, pretrained_retrieval=None):
super().__init__()
self.backbone = backbone
self.freeze_backbone = freeze_backbone
if freeze_backbone:
for p in self.backbone.parameters():
p.requires_grad = False
self.backbone_dim = backbone.enc_embed_dim
        # `prewhiten` doubles as the re-estimation frequency, see reinitialize_whitening()
        self.prewhiten = nn.Identity() if prewhiten is None else Whitener(self.backbone_dim)
        self.prewhiten_freq = prewhiten
        if prewhiten is not None and prewhiten != -1:
            for p in self.prewhiten.parameters():
                p.requires_grad = False
self.residual = residual
self.projector = self.build_projector(hdims, residual)
self.dim = hdims[-1] if len(hdims) > 0 else self.backbone_dim
self.postwhiten_freq = postwhiten
self.postwhiten = nn.Identity() if postwhiten is None else Whitener(self.dim)
if postwhiten is not None and postwhiten != -1:
assert len(hdims) > 0
for p in self.postwhiten.parameters():
p.requires_grad = False
self.featweights = featweights
if featweights == 'l2norm':
self.attention = lambda x: x.norm(dim=-1)
else:
raise NotImplementedError(featweights)
self.nfeat = nfeat
self.pretrained_retrieval = pretrained_retrieval
if self.pretrained_retrieval is not None:
            ckpt = torch.load(pretrained_retrieval, map_location='cpu')
msg = self.load_state_dict(ckpt['model'], strict=False)
assert len(msg.unexpected_keys) == 0 and all(k.startswith('backbone')
or k.startswith('postwhiten') for k in msg.missing_keys)
def build_projector(self, hdims, residual):
if self.residual:
assert hdims[-1] == self.backbone_dim
d = self.backbone_dim
if len(hdims) == 0:
return nn.Identity()
layers = []
for i in range(len(hdims) - 1):
layers.append(nn.Linear(d, hdims[i]))
d = hdims[i]
layers.append(nn.LayerNorm(d))
layers.append(nn.GELU())
layers.append(nn.Linear(d, hdims[-1]))
return nn.Sequential(*layers)
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
ss = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        if self.freeze_backbone:
            # the frozen backbone is not saved in checkpoints
            ss = {k: v for k, v in ss.items() if not k.startswith('backbone')}
return ss
def reinitialize_whitening(self, epoch, train_dataset, nimgs=5000, log_writer=None, max_nfeat_per_image=None, seed=0, device=default_device):
do_prewhiten = self.prewhiten_freq is not None and self.pretrained_retrieval is None and \
(epoch == 0 or (self.prewhiten_freq > 0 and epoch % self.prewhiten_freq == 0))
do_postwhiten = self.postwhiten_freq is not None and ((epoch == 0 and self.postwhiten_freq in [0, -1])
or (self.postwhiten_freq > 0 and
epoch % self.postwhiten_freq == 0 and epoch > 0))
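        # freq semantics: None disables re-estimation; -1 (or 0) estimates the
        # whitening once at the start of training; a positive value re-estimates
        # it every `freq` epochs.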
if do_prewhiten or do_postwhiten:
self.eval()
imdataset = train_dataset.imlist_dataset_n_images(nimgs, seed)
loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
if do_prewhiten:
print('Re-initialization of pre-whitening')
t = time.time()
with torch.no_grad():
features = []
for d in tqdm(loader):
feat = self.backbone._encode_image(d['img'][0, ...].to(device),
true_shape=d['true_shape'][0, ...])[0]
feat = feat.flatten(0, 1)
if max_nfeat_per_image is not None and max_nfeat_per_image < feat.size(0):
l2norms = torch.linalg.vector_norm(feat, dim=1)
feat = feat[torch.argsort(-l2norms)[:max_nfeat_per_image], :]
features.append(feat.cpu())
features = torch.cat(features, dim=0)
features = features.numpy()
m, P = pcawhitenlearn_shrinkage(features)
self.prewhiten.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})
prewhiten_time = time.time() - t
print(f'Done in {prewhiten_time:.1f} seconds')
if log_writer is not None:
log_writer.add_scalar('time/prewhiten', prewhiten_time, epoch)
if do_postwhiten:
                print('Re-initialization of post-whitening')
t = time.time()
with torch.no_grad():
features = []
for d in tqdm(loader):
backbone_feat = self.backbone._encode_image(d['img'][0, ...].to(device),
true_shape=d['true_shape'][0, ...])[0]
backbone_feat_prewhitened = self.prewhiten(backbone_feat)
proj_feat = self.projector(backbone_feat_prewhitened) + \
(0.0 if not self.residual else backbone_feat_prewhitened)
proj_feat = proj_feat.flatten(0, 1)
if max_nfeat_per_image is not None and max_nfeat_per_image < proj_feat.size(0):
l2norms = torch.linalg.vector_norm(proj_feat, dim=1)
proj_feat = proj_feat[torch.argsort(-l2norms)[:max_nfeat_per_image], :]
features.append(proj_feat.cpu())
features = torch.cat(features, dim=0)
features = features.numpy()
m, P = pcawhitenlearn_shrinkage(features)
self.postwhiten.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})
postwhiten_time = time.time() - t
print(f'Done in {postwhiten_time:.1f} seconds')
if log_writer is not None:
log_writer.add_scalar('time/postwhiten', postwhiten_time, epoch)
def extract_features_and_attention(self, x):
backbone_feat = self.backbone._encode_image(x['img'], true_shape=x['true_shape'])[0]
backbone_feat_prewhitened = self.prewhiten(backbone_feat)
proj_feat = self.projector(backbone_feat_prewhitened) + \
(0.0 if not self.residual else backbone_feat_prewhitened)
attention = self.attention(proj_feat)
proj_feat_whitened = self.postwhiten(proj_feat)
return proj_feat_whitened, attention
def forward_local(self, x):
feat, attn = self.extract_features_and_attention(x)
return how_select_local(feat, attn, self.nfeat)
def forward_global(self, x):
feat, attn = self.extract_features_and_attention(x)
return weighted_spoc(feat, attn)
def forward(self, x):
return self.forward_global(x)
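
# A minimal end-to-end sketch, assuming a MASt3R checkpoint: the import path,
# class name and checkpoint argument below are assumptions; any backbone that
# exposes enc_embed_dim and a dust3r-style _encode_image() should work.
def _demo_retrieval_model(image_list, checkpoint_path):
    from mast3r.model import AsymmetricMASt3R  # assumed import path
    backbone = AsymmetricMASt3R.from_pretrained(checkpoint_path).to(default_device)
    model = RetrievalModel(backbone, hdims=[1024], nfeat=300).to(default_device).eval()
    descs = []
    with torch.no_grad():
        for view in load_images(image_list, size=512, verbose=False):
            view['img'] = view['img'].to(default_device)
            view['true_shape'] = torch.from_numpy(view['true_shape'])
            descs.append(model.forward_global(view))  # (1, 1024) unit-norm descriptor
    return torch.cat(descs, dim=0)
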
def identity(x):
    # defined at module level so DataLoader workers can pickle it; a local lambda
    # would raise "Can't pickle local object 'extract_local_features.<locals>.<lambda>'"
    return x
@torch.no_grad()
def extract_local_features(model, images, imsize, seed=0, tocpu=False, max_nfeat_per_image=None,
max_nfeat_per_image2=None, device=default_device):
model.eval()
imdataset = Dust3rInputFromImageList(images, imsize=imsize) if isinstance(images, list) else images
loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False,
num_workers=8, pin_memory=True, collate_fn=identity)
    features = []
    imids = []
    for i, d in enumerate(tqdm(loader)):
        dd = d[0]
        dd['img'] = dd['img'].to(device, non_blocking=True)
        feat, _, _ = model.forward_local(dd)
        feat = feat.flatten(0, 1)  # (B, nfeat, C) -> (B * nfeat, C)
        if max_nfeat_per_image is not None and feat.size(0) > max_nfeat_per_image:
            # randomly subsample down to max_nfeat_per_image features
            feat = feat[torch.randperm(feat.size(0))[:max_nfeat_per_image], :]
        if max_nfeat_per_image2 is not None and feat.size(0) > max_nfeat_per_image2:
            feat = feat[:max_nfeat_per_image2, :]
        features.append(feat)
        if tocpu:
            features[-1] = features[-1].cpu()
        # record the image index of each kept feature
        imids.append(i * torch.ones_like(features[-1][:, 0]).to(dtype=torch.int64))
    features = torch.cat(features, dim=0)
    imids = torch.cat(imids, dim=0)
    return features, imids
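
# A minimal usage sketch: extract capped local features for a list of image
# paths; `model` is a RetrievalModel as sketched above.
def _demo_extract_local_features(model, image_paths):
    feats, imids = extract_local_features(model, image_paths, imsize=512,
                                          tocpu=True, max_nfeat_per_image=1000)
    # feats: (sum_i n_i, dim) stacked local features;
    # imids: (sum_i n_i,) int64 image index of each feature
    return feats, imids
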