import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    """DINO projection head: an MLP bottleneck followed by a weight-normalized output layer."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        # Weight-normalized output layer; the magnitude parameter (weight_g) is initialized to 1.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal initialization for linear weights, zeros for biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # Use a larger eps under fp16 so the L2 normalization does not underflow.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    # Stack of Linear (+ optional BatchNorm) + GELU blocks, ending with a projection to bottleneck_dim.
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
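

# Illustrative usage sketch (not part of the module above): build a DINOHead and run a
# forward pass on random features. The dimensions below are arbitrary assumptions chosen
# only to show the expected input/output shapes.
if __name__ == "__main__":
    head = DINOHead(in_dim=384, out_dim=65536, nlayers=3, hidden_dim=2048, bottleneck_dim=256)
    feats = torch.randn(8, 384)  # batch of 8 backbone embeddings (hypothetical sizes)
    logits = head(feats)  # shape (8, 65536): one score per output dimension
    print(logits.shape)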