Spaces:
Sleeping
Sleeping
import torch | |
from ..utils.base_model import BaseModel | |
import numpy as np | |
# Borrowed from DeDoDe.
def dual_softmax_matcher(
    desc_A: torch.Tensor,
    desc_B: torch.Tensor,
    threshold=0.1,
    inv_temperature=20,
    normalize=True,
):
    """Match two descriptor sets with a dual-softmax mutual-best criterion.

    The similarity matrix is soft-maxed along both the A axis and the B axis
    and the two results are multiplied. A pair (n, m) is kept when its
    probability is the maximum of both its row and its column and exceeds
    ``threshold``.

    Args:
        desc_A: Descriptors of shape (B, C, N), or (C, N) for a single
            unbatched set.
        desc_B: Descriptors of shape (B, C, M), or (C, M) when ``desc_A``
            is unbatched.
        threshold: Minimum dual-softmax probability for a match to be kept.
        inv_temperature: Factor applied to the similarities before softmax.
        normalize: If true, L2-normalize descriptors along the channel axis.

    Returns:
        Tuple ``(matches0, scores0)`` of tensors on the device of the
        similarity matrix, both of shape (B, N). ``matches0[b, n]`` is the
        index in ``desc_B`` matched to descriptor ``n`` of ``desc_A`` (``-1``
        when unmatched); ``scores0[b, n]`` is the probability of that match
        (``0`` when unmatched).
    """
    # Promote unbatched input BEFORE reading the shape; unpacking a 2-D shape
    # into three names would raise.
    if len(desc_A.shape) < 3:
        desc_A, desc_B = desc_A[None], desc_B[None]
    B = desc_A.shape[0]

    if normalize:
        desc_A = desc_A / desc_A.norm(dim=1, keepdim=True)
        desc_B = desc_B / desc_B.norm(dim=1, keepdim=True)

    sim = torch.einsum("b c n, b c m -> b n m", desc_A, desc_B) * inv_temperature
    P = sim.softmax(dim=-2) * sim.softmax(dim=-1)

    # Mutual best: maximum along both axes, and confident enough.
    mutual = (
        (P == P.max(dim=-1, keepdim=True).values)
        & (P == P.max(dim=-2, keepdim=True).values)
        & (P > threshold)
    )
    # Each row of `found` is (batch index, index in A, index in B).
    found = torch.nonzero(mutual).cpu().numpy()
    batch_idx, idx_A, idx_B = found[:, 0], found[:, 1], found[:, 2]

    matches0 = np.full((B, P.shape[-2]), -1, dtype=int)
    scores0 = np.zeros((B, P.shape[-2]), dtype=float)
    # Index with the batch column as well, so each match lands only in the
    # batch element it belongs to rather than being written to every row.
    matches0[batch_idx, idx_A] = idx_B
    scores0[batch_idx, idx_A] = P.cpu().numpy()[batch_idx, idx_A, idx_B]

    matches0 = torch.from_numpy(matches0).to(P.device)
    scores0 = torch.from_numpy(scores0).to(P.device)
    return matches0, scores0
class DualSoftMax(BaseModel):
    """Matcher model that delegates to ``dual_softmax_matcher``.

    Configuration keys:
        match_threshold: Passed to the matcher as ``threshold``.
        inv_temperature: Passed to the matcher as ``inv_temperature``.
    """

    default_conf = {
        "match_threshold": 0.2,
        "inv_temperature": 20,
    }
    # Descriptor shape: B x DIM x M
    required_inputs = ["descriptors0", "descriptors1"]

    def _init(self, conf):
        # Nothing to set up here.
        pass

    def _forward(self, data):
        """Match ``descriptors0`` against ``descriptors1``.

        Returns a dict with ``matches0`` (index into ``descriptors1`` for
        each descriptor of ``descriptors0``, ``-1`` when unmatched) and
        ``matching_scores0`` (score of each match, ``0`` when unmatched).
        """
        desc0 = data["descriptors0"]
        desc1 = data["descriptors1"]

        if desc0.size(-1) == 0 or desc1.size(-1) == 0:
            # Nothing can match. Return one "unmatched" entry per descriptor
            # of image 0, i.e. shape (B, N) like the non-empty path, instead
            # of (B, DIM) as ``shape[:2]`` would give.
            batch = desc0.shape[0] if desc0.dim() >= 3 else 1
            shape = (batch, desc0.size(-1))
            matches0 = torch.full(
                shape, -1, dtype=torch.long, device=desc0.device
            )
            # Scores are floating point, as in the non-empty path.
            scores0 = torch.zeros(shape, dtype=desc0.dtype, device=desc0.device)
            return {
                "matches0": matches0,
                "matching_scores0": scores0,
            }

        matches0, scores0 = dual_softmax_matcher(
            desc0,
            desc1,
            threshold=self.conf["match_threshold"],
            inv_temperature=self.conf["inv_temperature"],
        )
        return {
            "matches0": matches0,  # B x N
            "matching_scores0": scores0,
        }