AniDoc / LightGlue /lightglue /lightglue.py
fffiloni's picture
Migrated from GitHub
c705408 verified
raw
history blame
26.1 kB
import warnings
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
# Optional dependency: the flash-attn package supplies a fused cross-attention
# module. When it is not installed, the name is bound to None so that later
# code can test for it with a plain truthiness check.
try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

# A fast attention path exists when either flash-attn was imported or this
# PyTorch build exposes F.scaled_dot_product_attention.
FLASH_AVAILABLE = bool(
    FlashCrossAttention or hasattr(F, "scaled_dot_product_attention")
)

# Import-time side effect: request deterministic cuDNN algorithms globally.
torch.backends.cudnn.deterministic = True
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(
    kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Center keypoints and rescale them by half of the longest side.

    Args:
        kpts: keypoint coordinates; the second-to-last axis enumerates the
            points and the last axis holds the coordinates.
        size: extent used for the normalization. When ``None`` it is estimated
            from the keypoints themselves as ``1 + max - min`` per coordinate.
            Non-tensor values (e.g. a list) are converted to a tensor on the
            device/dtype of ``kpts``.

    Returns:
        ``(kpts - size / 2) / (max(size) / 2)``, broadcast over the point axis.

    The decorator requests float32 inputs (``cast_inputs``) for this function.
    """
    if size is None:
        upper = kpts.max(-2).values
        lower = kpts.min(-2).values
        size = 1 + upper - lower
    elif not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
    # Align device and dtype with the keypoints before doing arithmetic.
    size = size.to(kpts)
    center = size / 2
    half_longest_side = size.max(-1).values / 2
    return (kpts - center.unsqueeze(-2)) / half_longest_side[..., None, None]
def pad_to_length(
    x: torch.Tensor, length: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad ``x`` with ones along its second-to-last axis up to ``length``.

    Args:
        x: tensor whose second-to-last axis is the one being padded.
        length: target size of that axis.

    Returns:
        A pair ``(padded, mask)``. ``mask`` is a boolean tensor shaped like
        ``padded`` with a trailing axis of size 1; it is ``True`` for original
        rows and ``False`` for padding rows. If ``x`` is already at least
        ``length`` long it is returned unchanged with an all-``True`` mask.
    """
    current = x.shape[-2]
    if length <= current:
        return x, torch.ones_like(x[..., :1], dtype=torch.bool)

    filler = torch.ones(
        *x.shape[:-2], length - current, x.shape[-1], device=x.device, dtype=x.dtype
    )
    padded = torch.cat([x, filler], dim=-2)

    mask = torch.zeros(*padded.shape[:-1], 1, dtype=torch.bool, device=x.device)
    mask[..., :current, :] = True
    return padded, mask
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map each consecutive pair ``(a, b)`` on the last axis to ``(-b, a)``.

    The last axis is viewed as pairs, every pair is rotated, and the result is
    flattened back to the original shape.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)
def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Apply a precomputed rotary embedding to ``t``.

    ``freqs[0]`` multiplies ``t`` directly and ``freqs[1]`` multiplies the
    pair-rotated version of ``t`` (see ``rotate_half``); the two products are
    summed. In this file ``freqs`` is produced by
    ``LearnableFourierPositionalEncoding``, which stacks cosines at index 0
    and sines at index 1.
    """
    direct_term = t * freqs[0]
    rotated_term = rotate_half(t) * freqs[1]
    return direct_term + rotated_term
class LearnableFourierPositionalEncoding(nn.Module):
    """Positional encoding built from a learned linear projection.

    The input position vector is projected by a bias-free linear layer and the
    cosine and sine of the projection are returned, stacked on a new leading
    axis.
    """

    def __init__(
        self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0
    ) -> None:
        """
        Args:
            M: size of the input position vector.
            dim: fallback for ``F_dim`` when that argument is ``None``.
            F_dim: the projection has ``F_dim // 2`` output features.
            gamma: the projection weights are initialised from a normal
                distribution with standard deviation ``gamma ** -2``.
        """
        super().__init__()
        if F_dim is None:
            F_dim = dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector

        Returns a tensor whose leading axis has size 2 (cosines, then sines),
        with an extra singleton axis inserted at position -3 and every feature
        repeated twice along the last axis.
        """
        projected = self.Wr(x)
        cos_sin = torch.stack((torch.cos(projected), torch.sin(projected)), 0)
        return cos_sin.unsqueeze(-3).repeat_interleave(2, dim=-1)
class TokenConfidence(nn.Module):
    """Per-token confidence head: a linear layer followed by a sigmoid."""

    def __init__(self, dim: int) -> None:
        """
        Args:
            dim: feature size of the descriptors fed to the head.
        """
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens

        Both descriptor sets are detached first, so gradients of the
        confidences do not flow back into the descriptors. The trailing
        singleton axis of each output is squeezed away.
        """
        confidence0 = self.token(desc0.detach()).squeeze(-1)
        confidence1 = self.token(desc1.detach()).squeeze(-1)
        return (confidence0, confidence1)
class Attention(nn.Module):
    """Scaled dot-product attention with several interchangeable backends.

    Depending on what is available and where the inputs live, ``forward``
    dispatches to:

    * ``F.scaled_dot_product_attention`` on half-precision copies of the
      inputs (flash enabled, CUDA inputs, PyTorch provides SDP);
    * ``FlashCrossAttention`` from flash-attn (flash enabled, CUDA inputs,
      no SDP in PyTorch; masks are not supported on this path);
    * ``F.scaled_dot_product_attention`` in the input precision;
    * an explicit einsum/softmax implementation as a last resort.
    """

    def __init__(self, allow_flash: bool) -> None:
        """
        Args:
            allow_flash: whether the fast attention paths may be used. A
                warning is emitted if it is requested but nothing fast is
                available.
        """
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()
        if self.has_sdp:
            # NOTE: this toggles a process-wide PyTorch backend switch.
            torch.backends.cuda.enable_flash_sdp(allow_flash)

    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend from queries ``q`` to keys ``k`` and aggregate values ``v``.

        Args:
            q, k, v: query, key and value tensors; the second-to-last axis
                enumerates tokens and the last axis holds features.
            mask: optional boolean mask broadcastable to the attention logits;
                ``False`` entries are excluded from the softmax.

        Returns:
            The attended values. When a mask is given, NaNs produced by rows
            that are masked out completely are replaced by zeros.
        """
        # Nothing to attend over: return zeros of the expected output shape.
        if q.shape[-2] == 0 or k.shape[-2] == 0:
            return q.new_zeros((*q.shape[:-1], v.shape[-1]))
        if self.enable_flash and q.device.type == "cuda":
            # use torch 2.0 scaled_dot_product_attention with flash
            if self.has_sdp:
                args = [x.half().contiguous() for x in [q, k, v]]
                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
                return v if mask is None else v.nan_to_num()
            else:
                assert mask is None
                q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
                m = self.flash_(q.half(), torch.stack([k, v], 2).half())
                return m.transpose(-2, -3).to(q.dtype).clone()
        elif self.has_sdp:
            args = [x.contiguous() for x in [q, k, v]]
            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
            return v if mask is None else v.nan_to_num()
        else:
            s = q.shape[-1] ** -0.5
            sim = torch.einsum("...id,...jd->...ij", q, k) * s
            if mask is not None:
                # masked_fill is out-of-place: the result must be kept,
                # otherwise the mask has no effect at all.
                sim = sim.masked_fill(~mask, -float("inf"))
            attn = F.softmax(sim, -1)
            out = torch.einsum("...ij,...jd->...id", attn, v)
            # Same NaN handling as the SDP branches: fully masked rows give
            # NaN after the softmax and are mapped to zeros.
            return out if mask is None else out.nan_to_num()
class SelfBlock(nn.Module):
    """Self-attention block with rotary position encoding and a residual MLP.

    Queries, keys and values come from one fused linear projection. The
    attention output is projected, concatenated with the input and passed
    through a two-layer feed-forward network whose result is added to the
    input.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        """
        Args:
            embed_dim: feature size of the tokens; must be divisible by
                ``num_heads``.
            num_heads: number of attention heads.
            flash: forwarded to ``Attention`` to allow the fast backends.
            bias: whether the Q/K/V and output projections carry a bias.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        hidden = 2 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        encoding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Update ``x`` with a message attended from its own tokens.

        Args:
            x: input tokens (presumably ``[B x N x embed_dim]`` — the code
                swaps axes 1 and 2 to move heads in front of tokens).
            encoding: cached rotary embedding applied to queries and keys.
            mask: optional attention mask forwarded to the attention module.
        """
        # The fused projection is laid out as (heads, head_dim, 3) on the
        # last axis; the final axis selects q / k / v.
        fused = self.Wqkv(x).unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        q, k, v = fused.unbind(dim=-1)
        q = apply_cached_rotary_emb(encoding, q)
        k = apply_cached_rotary_emb(encoding, k)
        attended = self.inner_attn(q, k, v, mask=mask)
        # Put tokens back in front of heads and merge heads into features.
        merged_heads = attended.transpose(1, 2).flatten(start_dim=-2)
        message = self.out_proj(merged_heads)
        return x + self.ffn(torch.cat([x, message], -1))
class CrossBlock(nn.Module):
    """Bidirectional cross-attention block between two token sets.

    One projection is shared for queries and keys, so a single similarity
    matrix serves both directions (set 0 -> set 1 and set 1 -> set 0). Each
    set then receives a message from the other and is updated through a
    shared residual feed-forward network.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        """
        Args:
            embed_dim: feature size of the tokens.
            num_heads: number of attention heads.
            flash: use the ``Attention`` module (fast backends) when available.
            bias: whether the linear projections carry a bias.
        """
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        hidden = 2 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
        )
        self.flash = Attention(True) if flash and FLASH_AVAILABLE else None

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        """Apply ``func`` to both inputs and return the two results."""
        return func(x0), func(x1)

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        """Exchange messages between ``x0`` and ``x1``.

        Args:
            x0, x1: the two token sets.
            mask: optional boolean mask over (tokens of x0, tokens of x1);
                its transpose is used for the reverse direction.

        Returns:
            The updated ``(x0, x1)`` pair.
        """

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.unflatten(-1, (self.heads, -1)).transpose(1, 2)

        def merge_heads(t: torch.Tensor) -> torch.Tensor:
            return t.transpose(1, 2).flatten(start_dim=-2)

        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        qk0, qk1, v0, v1 = (split_heads(t) for t in (qk0, qk1, v0, v1))

        use_attention_module = self.flash is not None and qk0.device.type == "cuda"
        if use_attention_module:
            reverse_mask = None if mask is None else mask.transpose(-1, -2)
            m0 = self.flash(qk0, qk1, v1, mask)
            m1 = self.flash(qk1, qk0, v0, reverse_mask)
        else:
            # Each side gets sqrt(scale) so their product carries the usual
            # 1/sqrt(head_dim) factor.
            root_scale = self.scale**0.5
            qk0, qk1 = qk0 * root_scale, qk1 * root_scale
            sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
            if mask is not None:
                sim = sim.masked_fill(~mask, -float("inf"))
            attn01 = F.softmax(sim, dim=-1)
            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
            if mask is not None:
                # Fully masked rows yield NaN after the softmax; zero them.
                m0, m1 = m0.nan_to_num(), m1.nan_to_num()

        m0, m1 = self.map_(merge_heads, m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
        return x0, x1
class TransformerLayer(nn.Module):
    """One LightGlue layer: self-attention on each set, then cross-attention."""

    def __init__(self, *args, **kwargs):
        """All arguments are forwarded to both ``SelfBlock`` and ``CrossBlock``."""
        super().__init__()
        self.self_attn = SelfBlock(*args, **kwargs)
        self.cross_attn = CrossBlock(*args, **kwargs)

    def forward(
        self,
        desc0,
        desc1,
        encoding0,
        encoding1,
        mask0: Optional[torch.Tensor] = None,
        mask1: Optional[torch.Tensor] = None,
    ):
        """Update both descriptor sets.

        The masked path is taken only when both masks are provided; if either
        one is missing, the layer runs without any masking.
        """
        if mask0 is None or mask1 is None:
            desc0 = self.self_attn(desc0, encoding0)
            desc1 = self.self_attn(desc1, encoding1)
            return self.cross_attn(desc0, desc1)
        # Looked up on the instance at call time so that a replacement
        # installed on ``self.masked_forward`` is honoured.
        return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)

    # This part is compiled and allows padding inputs
    def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
        """Same as the unmasked path, but with pairwise validity masks.

        Each pairwise mask is the logical AND of one per-token mask with the
        transpose of another, so an entry is valid only when both tokens are.
        NOTE(review): with masks shaped like those returned by
        ``pad_to_length`` the pairwise masks are 3-D and rely on broadcasting
        against per-head attention logits — this looks safe only for a batch
        size of 1; confirm before using larger batches.
        """
        cross_mask = mask0 & mask1.transpose(-1, -2)
        self_mask0 = mask0 & mask0.transpose(-1, -2)
        self_mask1 = mask1 & mask1.transpose(-1, -2)
        desc0 = self.self_attn(desc0, encoding0, self_mask0)
        desc1 = self.self_attn(desc1, encoding1, self_mask1)
        return self.cross_attn(desc0, desc1, cross_mask)
def sigmoid_log_double_softmax(
    sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity

    Args:
        sim: similarity matrix of shape ``[B x M x N]``.
        z0: matchability logits for the first set, ``[B x M x 1]``.
        z1: matchability logits for the second set, ``[B x N x 1]``.

    Returns:
        A ``[B x (M+1) x (N+1)]`` tensor. The top-left ``M x N`` part is the
        sum of the row-wise log-softmax, the column-wise log-softmax and the
        log-sigmoid of both matchability logits. The extra last column/row
        hold ``logsigmoid(-z0)`` / ``logsigmoid(-z1)``; the corner stays 0.
    """
    batch, rows, cols = sim.shape
    log_matchable = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    over_cols = F.log_softmax(sim, 2)
    # Softmax over the rows, computed on the transposed matrix and
    # transposed back.
    over_rows = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(
        -1, -2
    )
    scores = sim.new_zeros((batch, rows + 1, cols + 1))
    scores[:, :rows, :cols] = over_cols + over_rows + log_matchable
    scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
    scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
    return scores
class MatchAssignment(nn.Module):
    """Assignment head: projects descriptors and scores all pairs."""

    def __init__(self, dim: int) -> None:
        """
        Args:
            dim: feature size of the incoming descriptors.
        """
        super().__init__()
        self.dim = dim
        self.matchability = nn.Linear(dim, 1, bias=True)
        self.final_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """build assignment matrix from descriptors

        Returns:
            ``(scores, sim)``: the log assignment matrix produced by
            ``sigmoid_log_double_softmax`` and the raw pairwise similarity.
        """
        mdesc0 = self.final_proj(desc0)
        mdesc1 = self.final_proj(desc1)
        _, _, d = mdesc0.shape
        # Each side is divided by d**0.25, so the dot product is scaled by
        # 1 / sqrt(d) overall.
        quarter_root = d**0.25
        mdesc0, mdesc1 = mdesc0 / quarter_root, mdesc1 / quarter_root
        sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
        logits0 = self.matchability(desc0)
        logits1 = self.matchability(desc1)
        scores = sigmoid_log_double_softmax(sim, logits0, logits1)
        return scores, sim

    def get_matchability(self, desc: torch.Tensor):
        """Return the per-token matchability as a probability in (0, 1)."""
        logits = self.matchability(desc)
        return torch.sigmoid(logits).squeeze(-1)
def filter_matches(scores: torch.Tensor, th: float):
    """obtain matches from a log assignment matrix [Bx M+1 x N+1]

    Only the top-left ``M x N`` part is considered. A pair is kept when the
    two points are each other's best candidate (mutual check) and the
    exponentiated score exceeds ``th``.

    Returns:
        ``(m0, m1, mscores0, mscores1)``: for every point of each set the
        index of its match (``-1`` if none) and the score of its mutual best
        candidate (0 when the candidate is not mutual).
    """
    core = scores[:, :-1, :-1]
    best_of_rows, best_of_cols = core.max(2), core.max(1)
    m0, m1 = best_of_rows.indices, best_of_cols.indices

    positions0 = torch.arange(m0.shape[1], device=m0.device)[None]
    positions1 = torch.arange(m1.shape[1], device=m1.device)[None]
    # A point is mutual when following its best match and then that match's
    # best match leads back to the point itself.
    mutual0 = positions0 == m1.gather(1, m0)
    mutual1 = positions1 == m0.gather(1, m1)

    best_prob0 = best_of_rows.values.exp()
    zero = best_prob0.new_tensor(0)
    mscores0 = torch.where(mutual0, best_prob0, zero)
    mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)

    valid0 = mutual0 & (mscores0 > th)
    valid1 = mutual1 & valid0.gather(1, m1)
    m0 = torch.where(valid0, m0, -1)
    m1 = torch.where(valid1, m1, -1)
    return m0, m1, mscores0, mscores1
class LightGlue(nn.Module):
    """Keypoint matcher built from a stack of ``TransformerLayer`` modules.

    Every layer has its own ``MatchAssignment`` head, and all but the last
    layer have a ``TokenConfidence`` head. The confidences drive two
    optional mechanisms: stopping the layer loop early
    (``depth_confidence``) and dropping points between layers
    (``width_confidence``).
    """

    default_conf = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "add_scale_ori": False,
        "n_layers": 9,
        "num_heads": 4,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": 0.95,  # early stopping, disable with -1
        "width_confidence": 0.99,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }

    # Point pruning involves an overhead (gather).
    # Therefore, we only activate it if there are enough keypoints.
    pruning_keypoint_thresholds = {
        "cpu": -1,
        "mps": -1,
        "cuda": 1024,
        "flash": 1536,
    }

    required_data_keys = ["image0", "image1"]

    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    # Per-feature-type overrides applied on top of the configuration.
    features = {
        "superpoint": {
            "weights": "superpoint_lightglue",
            "input_dim": 256,
        },
        "disk": {
            "weights": "disk_lightglue",
            "input_dim": 128,
        },
        "aliked": {
            "weights": "aliked_lightglue",
            "input_dim": 128,
        },
        "sift": {
            "weights": "sift_lightglue",
            "input_dim": 128,
            "add_scale_ori": True,
        },
        "doghardnet": {
            "weights": "doghardnet_lightglue",
            "input_dim": 128,
            "add_scale_ori": True,
        },
    }

    def __init__(self, features="superpoint", **conf) -> None:
        """Build the network and load weights.

        Args:
            features: key of ``LightGlue.features`` selecting pretrained
                weights and the matching input settings, or ``None`` to skip
                the download. With ``None`` and ``conf["weights"]`` set, the
                weights are read from ``weights/<name>.pth`` next to this
                file.
            **conf: overrides for ``default_conf``.

        Raises:
            ValueError: if ``features`` is not ``None`` and not a known key.
        """
        super().__init__()
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        if features is not None:
            if features not in self.features:
                raise ValueError(
                    f"Unsupported features: {features} not in "
                    f"{{{','.join(self.features)}}}"
                )
            for k, v in self.features[features].items():
                setattr(conf, k, v)

        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        head_dim = conf.descriptor_dim // conf.num_heads
        # Positions are (x, y), plus scale and orientation when enabled.
        self.posenc = LearnableFourierPositionalEncoding(
            2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
        )

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

        self.transformers = nn.ModuleList(
            [TransformerLayer(d, h, conf.flash) for _ in range(n)]
        )

        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
        self.token_confidence = nn.ModuleList(
            [TokenConfidence(d) for _ in range(n - 1)]
        )
        self.register_buffer(
            "confidence_thresholds",
            torch.Tensor(
                [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
            ),
        )

        state_dict = None
        if features is not None:
            # The cache file is named after the selected weights so that each
            # feature type gets its own file. A fixed name would make every
            # feature type reuse whichever checkpoint was stored first. For
            # "superpoint" this is still "superpoint_lightglue.pth".
            # NOTE: model_dir is relative to the current working directory.
            state_dict = torch.hub.load_state_dict_from_url(
                self.url.format(self.version, features),
                model_dir="./LightGlue/ckpts",
                file_name=f"{conf.weights}.pth",
            )
        elif conf.weights is not None:
            path = Path(__file__).parent
            path = path / "weights/{}.pth".format(self.conf.weights)
            state_dict = torch.load(str(path), map_location="cpu")

        if state_dict:
            # rename old state dict entries
            for i in range(self.conf.n_layers):
                pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
                pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            # Single load after renaming; strict=False tolerates keys that do
            # not line up with this module.
            self.load_state_dict(state_dict, strict=False)

        # static lengths LightGlue is compiled for (only used with torch.compile)
        self.static_lengths = None

    def compile(
        self,
        mode="reduce-overhead",
        static_lengths=(256, 512, 768, 1024, 1280, 1536),
    ):
        """Wrap each layer's ``masked_forward`` with ``torch.compile``.

        Args:
            mode: forwarded to ``torch.compile``.
            static_lengths: keypoint counts that inputs are padded to. Inputs
                with more keypoints than the largest entry run without
                padding.

        While static lengths are in use, point pruning is switched off for
        inputs that get padded; a warning says so when pruning is configured.
        """
        if self.conf.width_confidence != -1:
            warnings.warn(
                "Point pruning is partially disabled for compiled forward.",
                stacklevel=2,
            )

        torch._inductor.cudagraph_mark_step_begin()
        for i in range(self.conf.n_layers):
            self.transformers[i].masked_forward = torch.compile(
                self.transformers[i].masked_forward, mode=mode, fullgraph=True
            )

        self.static_lengths = static_lengths

    def forward(self, data: dict) -> dict:
        """
        Match keypoints and descriptors between two images

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                image: [B x C x H x W] or image_size: [B x 2]
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                image: [B x C x H x W] or image_size: [B x 2]
        Output (dict):
            matches0: [B x M]
            matching_scores0: [B x M]
            matches1: [B x N]
            matching_scores1: [B x N]
            matches: List[[Si x 2]]
            scores: List[[Si]]
            stop: int
            prune0: [B x M]
            prune1: [B x N]

        Note: this implementation only reads ``image_size``; when it is
        absent the keypoints are normalized by their own extent.
        The computation runs under CUDA autocast when ``conf.mp`` is set.
        """
        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
            return self._forward(data)

    def _forward(self, data: dict) -> dict:
        """Run the matcher; see ``forward`` for the input/output layout."""
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"
        data0, data1 = data["image0"], data["image1"]
        kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
        b, m, _ = kpts0.shape
        b, n, _ = kpts1.shape
        device = kpts0.device
        size0, size1 = data0.get("image_size"), data1.get("image_size")
        kpts0 = normalize_keypoints(kpts0, size0).clone()
        kpts1 = normalize_keypoints(kpts1, size1).clone()

        if self.conf.add_scale_ori:
            # Append per-keypoint scale and orientation to the positions.
            kpts0 = torch.cat(
                [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )
            kpts1 = torch.cat(
                [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )
        desc0 = data0["descriptors"].detach().contiguous()
        desc1 = data1["descriptors"].detach().contiguous()

        assert desc0.shape[-1] == self.conf.input_dim
        assert desc1.shape[-1] == self.conf.input_dim

        if torch.is_autocast_enabled():
            desc0 = desc0.half()
            desc1 = desc1.half()

        mask0, mask1 = None, None
        c = max(m, n)
        # Padding to a static length only happens after compile() was called
        # and when the inputs fit into the largest configured length.
        do_compile = self.static_lengths and c <= max(self.static_lengths)
        if do_compile:
            kn = min([k for k in self.static_lengths if k >= c])
            desc0, mask0 = pad_to_length(desc0, kn)
            desc1, mask1 = pad_to_length(desc1, kn)
            kpts0, _ = pad_to_length(kpts0, kn)
            kpts1, _ = pad_to_length(kpts1, kn)
        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)
        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        # GNN + final_proj + assignment
        do_early_stop = self.conf.depth_confidence > 0
        do_point_pruning = self.conf.width_confidence > 0 and not do_compile
        pruning_th = self.pruning_min_kpts(device)
        if do_point_pruning:
            # ind0/ind1 map the surviving points back to original indices.
            ind0 = torch.arange(0, m, device=device)[None]
            ind1 = torch.arange(0, n, device=device)[None]
            # We store the index of the layer at which pruning is detected.
            prune0 = torch.ones_like(ind0)
            prune1 = torch.ones_like(ind1)
        token0, token1 = None, None
        for i in range(self.conf.n_layers):
            if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
                break
            desc0, desc1 = self.transformers[i](
                desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
            )
            if i == self.conf.n_layers - 1:
                continue  # no early stopping or adaptive width at last layer

            if do_early_stop:
                token0, token1 = self.token_confidence[i](desc0, desc1)
                # Padding entries (beyond m / n) are excluded from the test.
                if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
                    break
            if do_point_pruning and desc0.shape[-2] > pruning_th:
                scores0 = self.log_assignment[i].get_matchability(desc0)
                prunemask0 = self.get_pruning_mask(token0, scores0, i)
                keep0 = torch.where(prunemask0)[1]
                ind0 = ind0.index_select(1, keep0)
                desc0 = desc0.index_select(1, keep0)
                encoding0 = encoding0.index_select(-2, keep0)
                # Survivors get their counter bumped for this layer.
                prune0[:, ind0] += 1
            if do_point_pruning and desc1.shape[-2] > pruning_th:
                scores1 = self.log_assignment[i].get_matchability(desc1)
                prunemask1 = self.get_pruning_mask(token1, scores1, i)
                keep1 = torch.where(prunemask1)[1]
                ind1 = ind1.index_select(1, keep1)
                desc1 = desc1.index_select(1, keep1)
                encoding1 = encoding1.index_select(-2, keep1)
                prune1[:, ind1] += 1

        if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
            # Nothing left to match: return empty / "unmatched" outputs.
            m0 = desc0.new_full((b, m), -1, dtype=torch.long)
            m1 = desc1.new_full((b, n), -1, dtype=torch.long)
            mscores0 = desc0.new_zeros((b, m))
            mscores1 = desc1.new_zeros((b, n))
            matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
            mscores = desc0.new_empty((b, 0))
            if not do_point_pruning:
                prune0 = torch.ones_like(mscores0) * self.conf.n_layers
                prune1 = torch.ones_like(mscores1) * self.conf.n_layers
            return {
                "matches0": m0,
                "matches1": m1,
                "matching_scores0": mscores0,
                "matching_scores1": mscores1,
                "stop": i + 1,
                "matches": matches,
                "scores": mscores,
                "prune0": prune0,
                "prune1": prune1,
            }

        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]  # remove padding
        # Use the assignment head of the layer the loop stopped at.
        scores, _ = self.log_assignment[i](desc0, desc1)
        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
        matches, mscores = [], []
        for k in range(b):
            valid = m0[k] > -1
            m_indices_0 = torch.where(valid)[0]
            m_indices_1 = m0[k][valid]
            if do_point_pruning:
                # Translate indices among survivors to original indices.
                m_indices_0 = ind0[k, m_indices_0]
                m_indices_1 = ind1[k, m_indices_1]
            matches.append(torch.stack([m_indices_0, m_indices_1], -1))
            mscores.append(mscores0[k][valid])

        # TODO: Remove when hloc switches to the compact format.
        if do_point_pruning:
            # Scatter the results of the surviving points back into dense
            # tensors that cover all original keypoints.
            m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
            m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
            m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
            m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
            mscores0_ = torch.zeros((b, m), device=mscores0.device)
            mscores1_ = torch.zeros((b, n), device=mscores1.device)
            mscores0_[:, ind0] = mscores0
            mscores1_[:, ind1] = mscores1
            m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
        else:
            prune0 = torch.ones_like(mscores0) * self.conf.n_layers
            prune1 = torch.ones_like(mscores1) * self.conf.n_layers

        return {
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "stop": i + 1,
            "matches": matches,
            "scores": mscores,
            "prune0": prune0,
            "prune1": prune1,
        }

    def confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold

        Decays exponentially with the layer index from 0.9 (layer 0) towards
        0.8, and is clipped to the range [0, 1].
        """
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
        return np.clip(threshold, 0, 1)

    def get_pruning_mask(
        self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
    ) -> torch.Tensor:
        """mask points which should be removed

        The returned boolean mask is ``True`` for points to keep: those whose
        matchability ``scores`` exceed ``1 - width_confidence`` and, when
        ``confidences`` is given, those that are not yet confident.
        """
        keep = scores > (1 - self.conf.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: torch.Tensor,
        confidences1: torch.Tensor,
        layer_index: int,
        num_points: int,
    ) -> torch.Tensor:
        """evaluate stopping condition

        Returns a boolean tensor that is true when the fraction of points
        whose confidence reaches the layer threshold exceeds
        ``depth_confidence``. ``num_points`` is the denominator of that
        fraction.
        """
        confidences = torch.cat([confidences0, confidences1], -1)
        threshold = self.confidence_thresholds[layer_index]
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.conf.depth_confidence

    def pruning_min_kpts(self, device: torch.device):
        """Return the keypoint count above which pruning is applied.

        The "flash" entry is used on CUDA when flash attention is enabled and
        available; otherwise the entry for the device type is looked up.
        """
        if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
            return self.pruning_keypoint_thresholds["flash"]
        else:
            return self.pruning_keypoint_thresholds[device.type]