Spaces:
Sleeping
Sleeping
File size: 4,707 Bytes
4d4dd90 8320ccc 4d4dd90 8320ccc 4d4dd90 8320ccc 4d4dd90 8320ccc 4d4dd90 8320ccc 4d4dd90 2eaeef9 8320ccc 4d4dd90 2eaeef9 4d4dd90 2eaeef9 4d4dd90 2eaeef9 4d4dd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import sys
import urllib.request
from pathlib import Path
import numpy as np
import torch
import torchvision.transforms as tfm
from .. import logger
from ..utils.base_model import BaseModel
# Make the DUSt3R checkout under third_party/dust3r (resolved relative to this
# file) importable, so the `from dust3r...` imports that follow can succeed.
duster_path = Path(__file__).parent / "../../third_party/dust3r"
sys.path.append(str(duster_path))
from dust3r.cloud_opt import GlobalAlignerMode, global_aligner
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
# Module-wide compute device: use CUDA when torch reports it available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Duster(BaseModel):
    """Dense matcher built on the DUSt3R network.

    Both images are run through DUSt3R as a symmetric pair. The predicted
    per-pixel 3D points are aligned with the ``PairViewer`` global aligner.
    2D correspondences are the reciprocal nearest neighbours between the two
    confident 3D point clouds.
    """

    default_conf = {
        "name": "Duster3r",
        "model_path": duster_path / "model_weights/duster_vit_large.pth",
        "max_keypoints": 3000,
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        """Set up normalization, fetch the checkpoint if missing and load the net.

        ``conf`` itself is not used; settings are read from ``self.conf``.
        """
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Coerce to Path: download_weights relies on Path methods, and the
        # config layer may hand back a plain string.
        self.model_path = Path(self.conf["model_path"])
        # preprocess() reads the patch size as an attribute; set it here so
        # that method does not fail with AttributeError.
        self.vit_patch_size = self.conf["vit_patch_size"]
        self.download_weights()
        self.net = AsymmetricCroCo3DStereo.from_pretrained(
            str(self.model_path)
        ).to(device)
        logger.info("Loaded Dust3r model")

    def download_weights(self):
        """Download the DUSt3R ViT-Large checkpoint to ``self.model_path`` if absent.

        The data is first written to a sibling ``<name>.part`` file and moved
        into place only after the transfer completes. An interrupted download
        therefore cannot leave a truncated checkpoint that later runs would
        mistake for a complete one.
        """
        url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
        model_path = Path(self.model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        if model_path.is_file():
            return
        logger.info("Downloading Duster(ViT large)... (takes a while)")
        tmp_path = model_path.with_name(model_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, str(tmp_path))
            os.replace(tmp_path, model_path)
        finally:
            # Remove a partial file left behind by a failed transfer.
            if tmp_path.exists():
                tmp_path.unlink()

    def preprocess(self, img):
        """Resize ``img`` so H and W are multiples of the ViT patch size, then normalize.

        Args:
            img: image tensor indexed as (C, H, W), as unpacked from
                ``img.shape`` below.

        Returns:
            The normalized image with a leading batch dimension of size 1.
        """
        # the super-class already makes sure that img0,img1 have
        # same resolution and that h == w
        patch = self.vit_patch_size
        _, h, _ = img.shape
        if h % patch != 0:
            # Snap to the nearest multiple of the patch size, but never below
            # one patch (a 0 size would make resize fail on tiny inputs).
            imsize = max(patch, int(patch * round(h / patch)))
            img = tfm.functional.resize(img, imsize, antialias=True)
        _, new_h, new_w = img.shape
        if new_w % patch != 0:
            safe_w = max(patch, int(patch * round(new_w / patch)))
            img = tfm.functional.resize(img, (new_h, safe_w), antialias=True)
        return self.normalize(img).unsqueeze(0)

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: dict holding ``"image0"`` and ``"image1"`` tensors with the
                channel axis at dim 1 and 3 channels (the mean/std below are
                viewed as (1, 3, 1, 1)).

        Returns:
            dict with ``"keypoints0"`` and ``"keypoints1"``: float32 tensors of
            shape (N, 2) holding corresponding 2D pixel coordinates (as
            produced by ``xy_grid``) in image0 and image1. Both are empty
            (0, 2) tensors when either image has no confident 3D points.
        """
        img0, img1 = data["image0"], data["image1"]
        # Normalize with mean/std 0.5, the same statistics as self.normalize.
        # NOTE(review): presumably inputs are in [0, 1] — confirm with caller.
        mean = torch.tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
        std = torch.tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
        img0 = (img0 - mean) / std
        img1 = (img1 - mean) / std
        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)
        scene = global_aligner(
            output, device=device, mode=GlobalAlignerMode.PairViewer
        )

        # retrieve useful values from scene:
        imgs = scene.imgs
        confidence_masks = scene.get_masks()
        pts3d = scene.get_pts3d()

        # Keep only confident pixels: their 2D grid coordinates and 3D points.
        pts2d_list, pts3d_list = [], []
        for i in range(2):
            conf_i = confidence_masks[i].cpu().numpy()
            # imgs[i].shape[:2] = (H, W); xy_grid takes (W, H)
            pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])
            pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

        # Reciprocal NN search needs points on both sides.
        if len(pts3d_list[0]) == 0 or len(pts3d_list[1]) == 0:
            logger.warning(f"Matched {0} points")
            return {
                "keypoints0": torch.zeros([0, 2]),
                "keypoints1": torch.zeros([0, 2]),
            }

        reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
            *pts3d_list
        )
        logger.info(f"Found {num_matches} matches")
        mkpts1 = pts2d_list[1][reciprocal_in_P2]
        # Same result as pts2d_list[0][nn2_in_P1][reciprocal_in_P2], without
        # materializing the gathered array for every point of image1.
        mkpts0 = pts2d_list[0][nn2_in_P1[reciprocal_in_P2]]

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(mkpts0) > top_k:
            # Evenly spaced subsample (not score-based) to cap the match count.
            keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
            mkpts0 = mkpts0[keep]
            mkpts1 = mkpts1[keep]

        # float32 so the dtype matches the empty-result branch above.
        return {
            "keypoints0": torch.from_numpy(mkpts0.astype(np.float32)),
            "keypoints1": torch.from_numpy(mkpts1.astype(np.float32)),
        }
|