|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
def log_binom(n, k, eps=1e-7):
    """Approximate log(n choose k) with Stirling's formula.

    With log(m!) ~ m*log(m) - m, the linear terms cancel and
    log(nCk) ~ n*log(n) - k*log(k) - (n-k)*log(n-k).

    Args:
        n (torch.Tensor): Total count(s).
        k (torch.Tensor): Chosen count(s); broadcast against ``n``.
        eps (float, optional): Offset added to ``n`` and ``k`` (and once more
            inside the last logarithm) so ``log`` is not evaluated at exactly
            zero when ``k == 0`` or ``k == n``. Defaults to 1e-7.

    Returns:
        torch.Tensor: Elementwise approximation of log(nCk).
    """
    n_shifted = n + eps
    k_shifted = k + eps
    gap = n_shifted - k_shifted
    # The extra eps keeps log() finite when k == n (gap == 0).
    gap_log_arg = gap + eps
    return (n_shifted * torch.log(n_shifted)
            - k_shifted * torch.log(k_shifted)
            - gap * torch.log(gap_log_arg))
|
|
|
|
|
|
class LogBinomial(nn.Module):
    """Map a probability map to a tempered binomial distribution over classes.

    For class index k in [0, K-1] the score is the binomial log-pmf
    log C(K-1, k) + k*log(p) + (K-1-k)*log(1-p), divided by a temperature and
    normalised along the class axis by ``act``.
    """

    def __init__(self, n_classes=256, act=torch.softmax):
        """Compute log binomial distribution for n_classes

        Args:
            n_classes (int, optional): number of output classes. Defaults to 256.
            act (callable, optional): Normaliser applied to the tempered scores,
                called as ``act(scores, dim=1)``. Defaults to ``torch.softmax``.
        """
        super().__init__()
        self.K = n_classes
        self.act = act
        # Class indices 0..K-1 along the channel axis, shape (1, K, 1, 1), so
        # they broadcast against (N, 1, H, W) probability maps.
        self.register_buffer('k_idx', torch.arange(
            0, n_classes).view(1, -1, 1, 1))
        # Number of trials (K - 1) as a floating-point buffer of shape
        # (1, 1, 1, 1). torch.tensor replaces the discouraged legacy
        # torch.Tensor(...) constructor; pinning the default float dtype keeps
        # the buffer (and its state_dict entry) identical to before.
        self.register_buffer('K_minus_1', torch.tensor(
            [float(self.K - 1)], dtype=torch.get_default_dtype()).view(1, -1, 1, 1))

    def forward(self, x, t=1., eps=1e-4):
        """Compute log binomial distribution for x

        Args:
            x (torch.Tensor): probabilities, either (N, H, W) (a channel axis
                is inserted) or 4-D; a 4-D input must broadcast against the
                (1, K, 1, 1) class-index buffer, e.g. (N, 1, H, W).
            t (float, torch.Tensor - NCHW, optional): Temperature of distribution;
                scores are divided by it before ``act``. Defaults to 1..
            eps (float, optional): Probabilities and their complements are
                clamped to [eps, 1] before taking logs. Defaults to 1e-4.

        Returns:
            torch.Tensor: ``act`` applied along dim 1 to the tempered binomial
            log-scores; for an (N, 1, H, W) input the result is (N, K, H, W).
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)  # (N, H, W) -> (N, 1, H, W)

        # Clamp away from 0 so both logarithms stay finite.
        one_minus_x = torch.clamp(1 - x, eps, 1)
        x = torch.clamp(x, eps, 1)
        y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \
            torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x)
        return self.act(y / t, dim=1)
|
|
|
|
|
|
class ConditionalLogBinomial(nn.Module):
    """Predict a per-pixel tempered binomial distribution from features and a condition.

    A 1x1-conv MLP on the concatenated inputs produces four positive channels.
    Two of them give the binomial success probability p, and two give a
    temperature t in (min_temp, max_temp). Both are passed to ``LogBinomial``.
    """

    def __init__(self, in_features, condition_dim, n_classes=256,
                 bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7,
                 act=torch.softmax):
        """Conditional Log Binomial distribution

        Args:
            in_features (int): number of input channels in main feature
            condition_dim (int): number of input channels in condition feature
            n_classes (int, optional): Number of classes. Defaults to 256.
            bottleneck_factor (int, optional): Hidden dim factor; the hidden
                width is (in_features + condition_dim) // bottleneck_factor.
                Defaults to 2.
            p_eps (float, optional): small eps value added to the MLP outputs
                before forming ratios. Defaults to 1e-4.
            max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
            min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
            act (callable, optional): Normaliser forwarded to ``LogBinomial``.
                Defaults to ``torch.softmax``.

        Raises:
            ValueError: if the hidden width computed from ``bottleneck_factor``
                is smaller than 1. A zero-width hidden layer would leave the
                output layer's bias uninitialised and the output independent
                of the input.
        """
        super().__init__()
        bottleneck = (in_features + condition_dim) // bottleneck_factor
        if bottleneck < 1:
            raise ValueError(
                "bottleneck size (in_features + condition_dim) // bottleneck_factor "
                f"must be >= 1, got {bottleneck}")

        self.p_eps = p_eps
        self.max_temp = max_temp
        self.min_temp = min_temp
        # Registration order (transform, then mlp) is kept so state_dict keys
        # appear in the same order as before.
        self.log_binomial_transform = LogBinomial(n_classes, act=act)
        self.mlp = nn.Sequential(
            nn.Conv2d(in_features + condition_dim, bottleneck,
                      kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            # Two channels parameterise p and two parameterise t; Softplus
            # keeps all four strictly positive.
            nn.Conv2d(bottleneck, 2 + 2, kernel_size=1, stride=1, padding=0),
            nn.Softplus()
        )

    @staticmethod
    def _channel_ratio(pair, eps):
        """Return (a + eps) / ((a + eps) + (b + eps)) with a = pair[:, 0], b = pair[:, 1]."""
        pair = pair + eps
        return pair[:, 0, ...] / (pair[:, 0, ...] + pair[:, 1, ...])

    def forward(self, x, cond):
        """Forward pass

        Args:
            x (torch.Tensor - NCHW): Main feature
            cond (torch.Tensor - NCHW): condition feature

        Returns:
            torch.Tensor: Output of ``LogBinomial`` for the predicted
            probability and temperature maps.
        """
        pt = self.mlp(torch.cat((x, cond), dim=1))
        p, t = pt[:, :2, ...], pt[:, 2:, ...]

        # Success probability in (0, 1); the channel axis is dropped here.
        p = self._channel_ratio(p, self.p_eps)

        # Ratio in (0, 1) mapped linearly into (min_temp, max_temp), with a
        # channel axis restored so it broadcasts over the class dimension.
        t = self._channel_ratio(t, self.p_eps).unsqueeze(1)
        t = (self.max_temp - self.min_temp) * t + self.min_temp

        return self.log_binomial_transform(p, t)
|
|
|