import torch
import torch.nn as nn


class SeedBinRegressor(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are bounded on the (min_depth, max_depth) interval.

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Min depth value. Defaults to 1e-3.
            max_depth (float, optional): Max depth value. Defaults to 10.
        """
        super().__init__()
        self.version = "1_1"
        self.min_depth = min_depth
        self.max_depth = max_depth

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        Returns the normalized bin widths and the bin centers; one vector per pixel.
        """
        B = self._net(x)
        eps = 1e-3
        B = B + eps  # keep every bin width strictly positive
        B_widths_normed = B / B.sum(dim=1, keepdim=True)  # widths sum to 1 along the bin dim
        B_widths = (self.max_depth - self.min_depth) * B_widths_normed
        # Prepend a leading edge of min_depth along the bin (channel) dimension;
        # the pad spec reads (W_left, W_right, H_top, H_bottom, C_front, C_back).
        B_widths = nn.functional.pad(
            B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
        B_edges = torch.cumsum(B_widths, dim=1)  # running sum of widths gives the bin edges

        B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
        return B_widths_normed, B_centers
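
# Usage sketch (illustrative, not part of the original module; the feature
# size of 32 channels at 24x32 is an arbitrary assumption). Per the forward
# pass above, the normed widths sum to 1 over the bin dimension and the
# centers lie inside (min_depth, max_depth):
#
#     regressor = SeedBinRegressor(in_features=32, n_bins=16)
#     feats = torch.randn(2, 32, 24, 32)
#     widths_normed, centers = regressor(feats)
#     assert widths_normed.shape == (2, 16, 24, 32)
#     assert centers.shape == (2, 16, 24, 32)
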
class SeedBinRegressorUnnormed(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are unbounded.

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
            max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
        """
        super().__init__()
        self.version = "1_1"
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.Softplus()  # positive but unbounded activations serve directly as bin centers
        )

    def forward(self, x):
        """
        Returns the bin centers, one vector per pixel. The centers are returned
        twice to match the (widths, centers) interface of SeedBinRegressor.
        """
        B_centers = self._net(x)
        return B_centers, B_centers
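
# Illustrative note (an assumption-flagged sketch, not in the original file):
# unlike SeedBinRegressor, the Softplus outputs here are neither normalized
# nor sorted, so downstream code must not assume monotonic centers:
#
#     regressor = SeedBinRegressorUnnormed(in_features=32, n_bins=16)
#     centers, _ = regressor(torch.randn(1, 32, 8, 8))
#     # centers > 0 everywhere, but centers.sum(dim=1) != 1 in general
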
class Projector(nn.Module):
    def __init__(self, in_features, out_features, mlp_dim=128):
        """Projector MLP.

        Args:
            in_features (int): input channels
            out_features (int): output channels
            mlp_dim (int, optional): hidden dimension. Defaults to 128.
        """
        super().__init__()

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
        )

    def forward(self, x):
        return self._net(x)
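
# Usage sketch (illustrative; all sizes are assumptions): built from 1x1
# convolutions, the Projector changes only the channel count and leaves the
# spatial dimensions untouched, so it can map a backbone feature map onto the
# embedding width a splitter expects:
#
#     proj = Projector(in_features=256, out_features=128)
#     out = proj(torch.randn(2, 256, 24, 32))
#     assert out.shape == (2, 128, 24, 32)
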
class LinearSplitter(nn.Module):
    def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
        """Splits each of the prev_nbins previous bins into split_factor sub-bins,
        predicting the fractional split of every bin from the input features.

        Args:
            in_features (int): input channels
            prev_nbins (int): number of bins at the previous level
            split_factor (int, optional): sub-bins per previous bin. Defaults to 2.
            mlp_dim (int, optional): hidden dimension. Defaults to 128.
            min_depth (float, optional): Min depth value. Defaults to 1e-3.
            max_depth (float, optional): Max depth value. Defaults to 10.
        """
        super().__init__()

        self.prev_nbins = prev_nbins
        self.split_factor = split_factor
        self.min_depth = min_depth
        self.max_depth = max_depth

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
            nn.ReLU()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        x : feature block; shape - n, c, h, w
        b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding
        S = self._net(x)
        eps = 1e-3
        S = S + eps  # keep every split fraction strictly positive
        n, c, h, w = S.shape
        S = S.view(n, self.prev_nbins, self.split_factor, h, w)
        S_normed = S / S.sum(dim=2, keepdim=True)  # fractional splits of each previous bin

        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)

        # Renormalize: bilinear interpolation does not preserve the unit sum.
        b_prev = b_prev / b_prev.sum(dim=1, keepdim=True)

        b = b_prev.unsqueeze(2) * S_normed
        b = b.flatten(1, 2)  # (n, prev_nbins * split_factor, h, w)

        B_widths = (self.max_depth - self.min_depth) * b

        # Prepend a leading edge of min_depth along the bin dimension, then take
        # the running sum of widths to recover the bin edges.
        B_widths = nn.functional.pad(
            B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
        B_edges = torch.cumsum(B_widths, dim=1)

        B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
        return b, B_centers
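

if __name__ == "__main__":
    # Smoke test (an illustrative sketch, not part of the original module):
    # chain a seed regressor with a splitter the way a coarse-to-fine bin
    # refinement would. All channel counts and spatial sizes below are
    # arbitrary assumptions for the demo.
    seed = SeedBinRegressor(in_features=32, n_bins=16)
    splitter = LinearSplitter(in_features=64, prev_nbins=16, split_factor=2)
    proj = Projector(in_features=32, out_features=64)

    coarse_feats = torch.randn(2, 32, 12, 16)  # coarse-level features
    fine_feats = torch.randn(2, 64, 24, 32)    # finer-level features

    b_prev, seed_centers = seed(coarse_feats)  # each (2, 16, 12, 16)
    emb = proj(coarse_feats)                   # (2, 64, 12, 16), matches splitter in_features
    b, centers = splitter(fine_feats, b_prev, prev_b_embedding=emb)
    print(b.shape, centers.shape)              # each (2, 32, 24, 32): 16 bins split into 32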