# File: Detic/detic/modeling/roi_heads/zero_shot_classifier.py
# Web-page residue kept for reference (not code):
#   AK391 | files | 159f437 | raw | history blame | 3.08 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Linear, ShapeSpec
class ZeroShotClassifier(nn.Module):
    """Classification head that scores features against class embeddings.

    Input features go through a linear projection to ``zs_weight_dim``
    and are then matrix-multiplied with a ``D x (C + 1)`` embedding matrix.
    The extra last column is all zeros (presumably the background class --
    confirm against the caller).
    """

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        num_classes: int,
        zs_weight_path: str,
        zs_weight_dim: int = 512,
        use_bias: float = 0.0,
        norm_weight: bool = True,
        norm_temperature: float = 50.0,
    ):
        """Build the projection layer and the class-embedding matrix.

        Args:
            input_shape: shape of the incoming features; a plain int is
                accepted and treated as the channel count.
            num_classes: number of classes C (one zero column is added).
            zs_weight_path: path given to ``np.load``, or the literal
                ``'rand'`` for a randomly initialised, learnable matrix.
            zs_weight_dim: embedding dimension D.
            use_bias: a negative value enables a learnable scalar bias
                initialised to that value; values >= 0 disable the bias.
            norm_weight: L2-normalise the embeddings (per class column)
                and the projected features.
            norm_temperature: scale applied to the normalised features.
        """
        super().__init__()

        # Backward compatibility: a bare int is taken as the channel count.
        if isinstance(input_shape, int):
            input_shape = ShapeSpec(channels=input_shape)
        width = input_shape.width or 1
        height = input_shape.height or 1
        in_features = input_shape.channels * width * height

        self.norm_weight = norm_weight
        self.norm_temperature = norm_temperature

        # ``use_bias`` doubles as on/off switch (sign) and initial value.
        self.use_bias = use_bias < 0
        if self.use_bias:
            self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)

        # The projection is created before the embeddings so the order of
        # random draws and of registered tensors stays as it always was.
        self.linear = nn.Linear(in_features, zs_weight_dim)

        learnable = zs_weight_path == 'rand'
        embeddings = self._build_zs_weight(
            zs_weight_path, zs_weight_dim, num_classes, learnable, norm_weight
        )

        if learnable:
            # Random embeddings are trained together with the model.
            self.zs_weight = nn.Parameter(embeddings)
        else:
            # Loaded embeddings are stored as a non-trainable buffer.
            self.register_buffer('zs_weight', embeddings)

        assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape

    @staticmethod
    def _build_zs_weight(path, dim, num_classes, random_init, normalize):
        """Return the ``D x (C + 1)`` embedding matrix used by ``__init__``."""
        if random_init:
            embeddings = torch.randn((dim, num_classes))
            nn.init.normal_(embeddings, std=0.01)
        else:
            loaded = torch.tensor(np.load(path), dtype=torch.float32)
            embeddings = loaded.permute(1, 0).contiguous()  # D x C

        # Append one all-zero column on the class axis.
        zero_column = embeddings.new_zeros((dim, 1))
        embeddings = torch.cat([embeddings, zero_column], dim=1)  # D x (C + 1)

        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=0)
        return embeddings

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Collect the constructor keyword arguments from ``cfg``."""
        num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES
        box_head = cfg.MODEL.ROI_BOX_HEAD
        return dict(
            input_shape=input_shape,
            num_classes=num_classes,
            zs_weight_path=box_head.ZEROSHOT_WEIGHT_PATH,
            zs_weight_dim=box_head.ZEROSHOT_WEIGHT_DIM,
            use_bias=box_head.USE_BIAS,
            norm_weight=box_head.NORM_WEIGHT,
            norm_temperature=box_head.NORM_TEMP,
        )

    def forward(self, x, classifier=None):
        '''
        Score features against the class embeddings.

        Inputs:
            x: B x D' features.
            classifier: optional C' x D embeddings used instead of the
                stored ``zs_weight``; transposed to D x C' before use.
        Returns:
            B x (C + 1) scores with the stored embeddings, or B x C'
            scores when ``classifier`` is given.
        '''
        projected = self.linear(x)

        if classifier is None:
            weight = self.zs_weight
        else:
            weight = classifier.permute(1, 0).contiguous()  # D x C'
            if self.norm_weight:
                weight = F.normalize(weight, p=2, dim=0)

        if self.norm_weight:
            # Unit-length features scaled by the temperature.
            projected = self.norm_temperature * F.normalize(projected, p=2, dim=1)

        scores = torch.mm(projected, weight)
        if self.use_bias:
            scores = scores + self.cls_bias
        return scores