|
import torch |
|
import torch.nn as nn |
|
import torchvision |
|
from timm.models.vision_transformer import Block |
|
import math |
|
|
|
import gazelle.utils as utils |
|
from gazelle.backbone import DinoV2Backbone |
|
|
|
|
|
class GazeLLE(nn.Module):
    """Gaze-target heatmap model built on top of an image feature backbone.

    The backbone feature map is projected to ``dim`` channels. A fixed 2D
    sin/cos positional encoding is added, and a learned "head" embedding is
    added on the cells covered by each person's head box. The result is
    flattened into a token sequence and run through ``num_layers`` timm
    transformer ``Block``s. A transposed-conv head then turns the tokens into
    one heatmap per person. When ``inout`` is enabled, an extra learned token
    is prepended to the sequence and decoded into a per-person scalar in (0, 1).

    Args:
        backbone: feature extractor providing ``get_out_size(in_size)``,
            ``get_dimension()`` and ``forward(images)``.
        inout: whether to build and run the in/out prediction branch.
        dim: internal token dimension (must be divisible by 4, as required by
            ``positionalencoding2d``).
        num_layers: number of transformer blocks.
        in_size: input image size handed to ``backbone.get_out_size``.
        out_size: spatial size the predicted heatmaps are resized to.
    """

    def __init__(self, backbone, inout=False, dim=256, num_layers=3, in_size=(448, 448), out_size=(64, 64)):
        super().__init__()
        self.backbone = backbone
        self.dim = dim
        self.num_layers = num_layers
        self.featmap_h, self.featmap_w = backbone.get_out_size(in_size)
        self.in_size = in_size
        self.out_size = out_size
        self.inout = inout

        # 1x1 conv projecting backbone channels down to the model dimension.
        self.linear = nn.Conv2d(backbone.get_dimension(), self.dim, 1)
        # Fixed positional encoding of shape (dim, featmap_h, featmap_w). It is
        # registered as a buffer so it moves with .to(device) and is saved in
        # state_dict.
        self.register_buffer("pos_embed", positionalencoding2d(self.dim, self.featmap_h, self.featmap_w))
        self.transformer = nn.Sequential(*[
            Block(
                dim=self.dim,
                num_heads=8,
                mlp_ratio=4,
                drop_path=0.1)
            for _ in range(num_layers)
        ])
        # Upsamples the token grid 2x, then reduces it to a single-channel map in (0, 1).
        self.heatmap_head = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
            nn.Conv2d(dim, 1, kernel_size=1, bias=False),
            nn.Sigmoid()
        )
        # Learned embedding added only on feature cells inside the person's head box.
        self.head_token = nn.Embedding(1, self.dim)
        if self.inout:
            self.inout_head = nn.Sequential(
                nn.Linear(self.dim, 128),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(128, 1),
                nn.Sigmoid()
            )
            # Learned token prepended to the sequence; its output feeds inout_head.
            self.inout_token = nn.Embedding(1, self.dim)

    def forward(self, input):
        """Predict gaze heatmaps (and optionally in/out scores) for every person.

        Args:
            input: dict with ``"images"`` (passed to the backbone) and
                ``"bboxes"``. ``"bboxes"`` holds one sequence per image, each
                entry either ``None`` or an ``(xmin, ymin, xmax, ymax)`` head
                box. Box coordinates are fractions of the image size, since
                they are scaled by the feature-map size.

        Returns:
            dict with two entries:
            ``"heatmap"``: heatmaps resized to ``out_size``, split per image
                by ``utils.split_tensors`` (presumably one tensor per image;
                confirm in utils).
            ``"inout"``: in/out scores split the same way, or ``None`` when
                ``inout`` is disabled.
        """
        num_ppl_per_img = [len(bbox_list) for bbox_list in input["bboxes"]]
        x = self.backbone.forward(input["images"])
        x = self.linear(x)
        x = x + self.pos_embed
        # Duplicate each image's features once per person in that image.
        x = utils.repeat_tensors(x, num_ppl_per_img)
        head_maps = torch.cat(self.get_input_head_maps(input["bboxes"]), dim=0).to(x.device)
        # (N, 1, h, w) * (1, dim, 1, 1): the head token lands only on head-box cells.
        head_map_embeddings = head_maps.unsqueeze(dim=1) * self.head_token.weight.unsqueeze(-1).unsqueeze(-1)
        x = x + head_map_embeddings
        # (N, dim, h, w) -> (N, h*w, dim) token sequence for the transformer.
        x = x.flatten(start_dim=2).permute(0, 2, 1)

        if self.inout:
            x = torch.cat([self.inout_token.weight.unsqueeze(dim=0).repeat(x.shape[0], 1, 1), x], dim=1)

        x = self.transformer(x)

        if self.inout:
            inout_tokens = x[:, 0, :]
            inout_preds = self.inout_head(inout_tokens).squeeze(dim=-1)
            inout_preds = utils.split_tensors(inout_preds, num_ppl_per_img)
            # Drop the in/out token so only the spatial tokens remain.
            x = x[:, 1:, :]

        # Back to a (N, dim, h, w) grid for the convolutional heatmap head.
        x = x.reshape(x.shape[0], self.featmap_h, self.featmap_w, x.shape[2]).permute(0, 3, 1, 2)
        x = self.heatmap_head(x).squeeze(dim=1)
        x = torchvision.transforms.functional.resize(x, self.out_size)
        heatmap_preds = utils.split_tensors(x, num_ppl_per_img)

        return {"heatmap": heatmap_preds, "inout": inout_preds if self.inout else None}

    def get_input_head_maps(self, bboxes):
        """Rasterize head boxes into binary masks on the feature-map grid.

        Args:
            bboxes: one sequence per image. Each entry is ``None`` or an
                ``(xmin, ymin, xmax, ymax)`` box in fractions of the image size.

        Returns:
            list with one float tensor per image, of shape
            ``(num_boxes, featmap_h, featmap_w)``. Cells covered by a box are
            1 and all other cells are 0. A ``None`` box gives an all-zero mask,
            and an image with no boxes gives a ``(0, featmap_h, featmap_w)``
            tensor, so ``torch.cat`` still works.
        """
        height, width = self.featmap_h, self.featmap_w
        head_maps = []
        for bbox_list in bboxes:
            img_head_maps = []
            for bbox in bbox_list:
                head_map = torch.zeros((height, width))
                if bbox is not None:
                    xmin, ymin, xmax, ymax = bbox
                    x0 = self._bbox_coord_to_cell(xmin, width)
                    x1 = self._bbox_coord_to_cell(xmax, width)
                    y0 = self._bbox_coord_to_cell(ymin, height)
                    y1 = self._bbox_coord_to_cell(ymax, height)
                    head_map[y0:y1, x0:x1] = 1
                img_head_maps.append(head_map)
            if img_head_maps:
                head_maps.append(torch.stack(img_head_maps))
            else:
                # torch.stack([]) raises, so emit an explicitly empty batch instead.
                head_maps.append(torch.zeros((0, height, width)))
        return head_maps

    @staticmethod
    def _bbox_coord_to_cell(coord, size):
        """Map a fractional box coordinate to a grid index clamped to [0, size]."""
        # Clamping matters: a box slightly outside the image would otherwise give
        # a negative index, which slicing interprets as counting from the end.
        return min(max(int(round(float(coord) * size)), 0), size)

    def get_gazelle_state_dict(self, include_backbone=False):
        """Return the model's state_dict, excluding ``backbone*`` keys unless requested."""
        if include_backbone:
            return self.state_dict()
        return {k: v for k, v in self.state_dict().items() if not k.startswith("backbone")}

    def load_gazelle_state_dict(self, ckpt_state_dict, include_backbone=False):
        """Load matching entries from ``ckpt_state_dict`` into this model.

        Only keys present in both the model and the checkpoint are copied.
        Keys missing on either side are reported with ``print`` and otherwise
        ignored. Unless ``include_backbone`` is set, keys starting with
        ``"backbone"`` are ignored on both sides, so the backbone keeps its
        current weights.
        """
        current_state_dict = self.state_dict()
        model_keys = set(current_state_dict.keys())
        ckpt_keys = set(ckpt_state_dict.keys())
        if not include_backbone:
            model_keys = {k for k in model_keys if not k.startswith("backbone")}
            ckpt_keys = {k for k in ckpt_keys if not k.startswith("backbone")}

        unused_keys = ckpt_keys - model_keys
        missing_keys = model_keys - ckpt_keys
        if unused_keys:
            print("WARNING unused keys in provided state dict: ", unused_keys)
        if missing_keys:
            print("WARNING provided state dict does not have values for keys: ", missing_keys)

        for k in model_keys & ckpt_keys:
            current_state_dict[k] = ckpt_state_dict[k]

        # current_state_dict holds every model key (including the untouched
        # backbone ones), so the only keys strict=False can skip are none.
        self.load_state_dict(current_state_dict, strict=False)
|
|
|
|
|
|
|
def positionalencoding2d(d_model, height, width):
    """Build a fixed 2D sine/cosine positional encoding.

    The first half of the channels encodes the column (x) position and the
    second half encodes the row (y) position. Within each half, even channels
    hold sines and odd channels hold cosines, over frequencies
    ``10000 ** (-k / (d_model / 2))`` for ``k = 0, 2, 4, ...``.

    :param d_model: number of channels; must be divisible by 4
    :param height: number of rows
    :param width: number of columns
    :return: float tensor of shape (d_model, height, width)
    :raises ValueError: if d_model is not divisible by 4
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)

    half = d_model // 2
    # Frequency ladder shared by both axes.
    div_term = torch.exp(torch.arange(0., half, 2) * -(math.log(10000.0) / half))
    pos_w = torch.arange(0., width).unsqueeze(1)   # (width, 1)
    pos_h = torch.arange(0., height).unsqueeze(1)  # (height, 1)
    # (width, half/2) -> (half/2, 1, width): constant down each column.
    w_angles = (pos_w * div_term).transpose(0, 1).unsqueeze(1)
    # (height, half/2) -> (half/2, height, 1): constant along each row.
    h_angles = (pos_h * div_term).transpose(0, 1).unsqueeze(2)

    pe[0:half:2] = torch.sin(w_angles).expand(-1, height, width)
    pe[1:half:2] = torch.cos(w_angles).expand(-1, height, width)
    pe[half::2] = torch.sin(h_angles).expand(-1, height, width)
    pe[half + 1::2] = torch.cos(h_angles).expand(-1, height, width)

    return pe
|
|
|
|
|
|
|
def get_gazelle_model(model_name):
    """Build a Gaze-LLE variant by its registered name.

    :param model_name: one of the keys in the builder table below
    :return: the value returned by the matching builder, which for every
        builder in this module is a ``(model, transform)`` pair
    """
    builders = dict(
        gazelle_dinov2_vitb14=gazelle_dinov2_vitb14,
        gazelle_dinov2_vitl14=gazelle_dinov2_vitl14,
        gazelle_dinov2_vitb14_inout=gazelle_dinov2_vitb14_inout,
        gazelle_dinov2_vitl14_inout=gazelle_dinov2_vitl14_inout,
    )
    assert model_name in builders, "invalid model name"
    build = builders[model_name]
    return build()
|
|
|
def gazelle_dinov2_vitb14():
    """Gaze-LLE on a ``dinov2_vitb14`` backbone, without the in/out branch.

    :return: ``(model, transform)``, where ``transform`` comes from
        ``backbone.get_transform((448, 448))``
    """
    encoder = DinoV2Backbone('dinov2_vitb14')
    preprocess = encoder.get_transform((448, 448))
    return GazeLLE(encoder), preprocess
|
|
|
def gazelle_dinov2_vitl14():
    """Gaze-LLE on a ``dinov2_vitl14`` backbone, without the in/out branch.

    :return: ``(model, transform)``, where ``transform`` comes from
        ``backbone.get_transform((448, 448))``
    """
    encoder = DinoV2Backbone('dinov2_vitl14')
    preprocess = encoder.get_transform((448, 448))
    return GazeLLE(encoder), preprocess
|
|
|
def gazelle_dinov2_vitb14_inout():
    """Gaze-LLE on a ``dinov2_vitb14`` backbone, with the in/out branch enabled.

    :return: ``(model, transform)``, where ``transform`` comes from
        ``backbone.get_transform((448, 448))``
    """
    encoder = DinoV2Backbone('dinov2_vitb14')
    preprocess = encoder.get_transform((448, 448))
    return GazeLLE(encoder, inout=True), preprocess
|
|
|
def gazelle_dinov2_vitl14_inout():
    """Gaze-LLE on a ``dinov2_vitl14`` backbone, with the in/out branch enabled.

    :return: ``(model, transform)``, where ``transform`` comes from
        ``backbone.get_transform((448, 448))``
    """
    encoder = DinoV2Backbone('dinov2_vitl14')
    preprocess = encoder.get_transform((448, 448))
    return GazeLLE(encoder, inout=True), preprocess
|
|