Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import urllib.request | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import torchvision.transforms as tfm | |
from .. import logger | |
from ..utils.base_model import BaseModel | |
# Locate the vendored MASt3R and DUSt3R checkouts relative to this file and
# append them to sys.path so the `mast3r.*` / `dust3r.*` imports resolve.
mast3r_path = Path(__file__).parent.joinpath("../../third_party/mast3r")
dust3r_path = Path(__file__).parent.joinpath("../../third_party/dust3r")
sys.path.extend([str(mast3r_path), str(dust3r_path)])
from mast3r.model import AsymmetricMASt3R | |
from mast3r.fast_nn import fast_reciprocal_NNs | |
from dust3r.image_pairs import make_pairs | |
from dust3r.inference import inference | |
from dust3r.utils.image import load_images | |
from hloc.matchers.duster import Duster | |
# Module-wide compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Mast3r(Duster):
    """Two-view matcher backed by a pretrained ``AsymmetricMASt3R`` network.

    The checkpoint is downloaded on first use. ``_forward`` runs the network
    on an image pair, matches the predicted dense descriptors with
    ``fast_reciprocal_NNs`` and returns the matched coordinates.
    """

    default_conf = {
        "name": "Mast3r",
        "model_path": mast3r_path
        / "model_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
        "max_keypoints": 2000,
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        """Fetch the checkpoint if needed and load the network onto ``device``.

        Args:
            conf: configuration passed by the base class; this method reads
                its settings from ``self.conf`` instead.
        """
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Accept a plain string as well as a Path for an overridden location.
        self.model_path = Path(self.conf["model_path"])
        self.download_weights()
        self.net = AsymmetricMASt3R.from_pretrained(self.model_path).to(device)
        logger.info("Loaded Mast3r model")

    def download_weights(self):
        """Download the MASt3R ViT-Large checkpoint unless it already exists.

        The file is first written to a ``.part`` sibling and only moved to
        ``self.model_path`` once the transfer has finished, so an interrupted
        download cannot leave a truncated checkpoint that later runs would
        mistake for a complete one.

        Raises:
            Whatever ``urllib.request.urlretrieve`` or ``os.replace`` raise;
            the partial file is removed before the error propagates.
        """
        url = (
            "https://download.europe.naverlabs.com/ComputerVision/MASt3R/"
            "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
        )
        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        if os.path.isfile(self.model_path):
            return

        logger.info("Downloading Mast3r(ViT large)... (takes a while)")
        partial_path = self.model_path.with_name(self.model_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial_path)
        except BaseException:
            # Also covers KeyboardInterrupt: never keep a half-written file.
            if partial_path.exists():
                partial_path.unlink()
            raise
        os.replace(partial_path, self.model_path)
        logger.info("Downloading Mast3r(ViT large)... done!")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: mapping with the tensors ``"image0"`` and ``"image1"``.
                Assumes each is a (B, 3, H, W) float batch, since the
                normalization constants are broadcast as (1, 3, 1, 1);
                TODO confirm against the caller.

        Returns:
            dict with ``"keypoints0"`` and ``"keypoints1"``: the matched
            coordinates in each image, evenly subsampled down to
            ``conf["max_keypoints"]`` when more matches are found, or two
            empty ``(0, 2)`` tensors when nothing matched.
        """
        img0, img1 = data["image0"], data["image1"]
        # Build the constants on the inputs' own device so the arithmetic
        # cannot hit a device mismatch with the module-level ``device``.
        mean = torch.tensor([0.5, 0.5, 0.5], device=img0.device).view(1, 3, 1, 1)
        std = torch.tensor([0.5, 0.5, 0.5], device=img0.device).view(1, 3, 1, 1)
        img0 = (img0 - mean) / std
        img1 = (img1 - mean) / std

        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)

        # at this stage, you have the raw dust3r predictions
        pred1, pred2 = output["pred1"], output["pred2"]
        # NOTE(review): index 1 takes the second entry of the symmetrized pair
        # list, presumably the (image0, image1) ordering; verify against
        # make_pairs.
        desc1 = pred1["desc"][1].squeeze(0).detach()
        desc2 = pred2["desc"][1].squeeze(0).detach()

        # find 2D-2D matches between the two images
        matches_im0, matches_im1 = fast_reciprocal_NNs(
            desc1,
            desc2,
            subsample_or_initxy1=2,
            device=device,
            dist="dot",
            block_size=2**13,
        )
        mkpts0 = matches_im0.copy()
        mkpts1 = matches_im1.copy()

        if len(mkpts0) == 0:
            logger.warning("Matched 0 points")
            return {
                "keypoints0": torch.zeros([0, 2]),
                "keypoints1": torch.zeros([0, 2]),
            }

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(mkpts0) > top_k:
            # Keep top_k matches spread evenly over the returned order.
            keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
            mkpts0 = mkpts0[keep]
            mkpts1 = mkpts1[keep]
        return {
            "keypoints0": torch.from_numpy(mkpts0),
            "keypoints1": torch.from_numpy(mkpts1),
        }