PhyscalX's picture
TAP v1.1 models release
4cee877
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Prompt encoder."""
import torch
from torch import nn
class PromptEncoder(nn.Module):
    """Module to encode geometric prompts.

    Point and box prompts are normalized by the image size, projected through
    a fixed random matrix into sin/cos features and summed with a learned
    embedding selected by the prompt label. The same sin/cos projection is
    used to build a positional embedding for the image feature grid.

    Label indices follow ``[bg, fg, lt, rb, pad]`` (0..4).
    """

    def __init__(self, embed_dim, image_size):
        """Create the encoder.

        Args:
            embed_dim: Size of the produced embeddings; half of it is used
                for the sin part and half for the cos part.
            image_size: Side length used to normalize prompt coordinates
                (the same value is used for both axes).
        """
        super().__init__()
        self.point_embed = nn.Embedding(5, embed_dim)  # [bg, fg, lt, rb, pad]
        # Labels of the two corner points a box is decomposed into: [lt, rb].
        self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
        self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
        self.img_pos, self.img_size = None, [image_size] * 2
        # Grid size the cached ``img_pos`` was computed for.
        self._img_pos_size = None

    def as_tensor(self, input):
        """Convert input into a tensor on the device of ``coord_matrix``."""
        return torch.as_tensor(input, device=self.coord_matrix.device)

    def _normalize_coords(self, coords):
        """Scale array coords by the image size and return a float32 tensor."""
        # NOTE(review): the 0.5 offset presumably moves integer pixel indices
        # to pixel centres, and the last axis is presumably (x, y) -- confirm.
        coords = (coords + 0.5) / self.img_size[::-1]
        return self.as_tensor(coords.clip(0, 1).astype("float32"))

    def to_points(self, points=None, boxes=None):
        """Convert points or boxes to point prompts.

        Args:
            points: Either a ``(coords, labels)`` tuple/list of arrays, or a
                single rank-3 array whose last axis holds two coordinates
                followed by the label. Arrays must support ``clip`` and
                ``astype`` (assumes numpy arrays -- TODO confirm).
            boxes: Array reshaped to ``(-1, 2, 2)``, i.e. two corner points
                per box. Only used when ``points`` is None.

        Returns:
            A ``(coords, labels)`` tuple of tensors with coords clipped to
            [0, 1], or None when neither argument is given.
        """
        if points is not None:
            if isinstance(points, (tuple, list)):
                coords, labels = points
            else:
                coords, labels = points[:, :, :2], points[:, :, 2]
            labels = self.as_tensor(labels.astype("int64"))
            return self._normalize_coords(coords), labels
        if boxes is not None:
            coords = self._normalize_coords(boxes.reshape((-1, 2, 2)))
            return coords, self.as_tensor(self.corner_labels)
        return None

    def encode_coords(self, coords):
        """Return the embedding for given coords.

        Args:
            coords: Tensor whose last axis has size 2, with values expected
                in [0, 1].

        Returns:
            Tensor of ``[sin, cos]`` features in the dtype of the point
            embedding weights.
        """
        # The truncated pi value is kept as-is so that outputs stay
        # numerically identical to the previous implementation.
        pi4, pi2 = 4 * 3.1415926, 2 * 3.1415926
        # Keep the projection in float32 even if the module was cast.
        if self.coord_matrix.dtype != torch.float32:
            self.coord_matrix = self.coord_matrix.float()
        # Maps [0, 1] to [-2*pi, 2*pi] before the projection.
        rad = coords.mul(pi4).sub_(pi2) @ self.coord_matrix
        dtype = self.point_embed.weight.dtype
        return torch.cat([rad.sin(), rad.cos()], dim=-1).to(dtype=dtype)

    def encode_points(self, coords, labels):
        """Return the embedding for given points.

        Points labelled ``pad`` (4) get their coordinate features zeroed,
        so they are represented by the learned pad embedding only.
        """
        embed = self.encode_coords(coords)
        keep = labels.ne(4).unsqueeze_(-1).float().to(dtype=embed.dtype)
        embed.mul_(keep)
        return embed.add_(self.point_embed(labels))

    def encode_grid(self, grid_size):
        """Return the embedding for a grid of specified size.

        Args:
            grid_size: Two integers ``(rows, cols)``.

        Returns:
            Tensor of shape ``(rows, cols, 2 * coord_matrix.shape[1])``.
        """
        grid = torch.ones(*grid_size, dtype=torch.float32)
        # cumsum yields 1..n, so each value becomes (index + 0.5) / n.
        y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
        x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
        coords = self.as_tensor(torch.stack([x, y], dim=-1))
        return self.encode_coords(coords)

    def _get_img_pos(self, grid_size):
        """Return the flattened grid embedding, recomputing it when stale."""
        grid_size = tuple(grid_size)
        cached = self.img_pos
        # Comparing only the number of cells would reuse the embedding of a
        # differently shaped grid with the same area (e.g. 2x8 vs 4x4), and
        # would miss a device or dtype change of the module.
        stale = (
            cached is None
            or getattr(self, "_img_pos_size", None) != grid_size
            or cached.device != self.coord_matrix.device
            or cached.dtype != self.point_embed.weight.dtype
        )
        if stale:
            self.img_pos = self.encode_grid(grid_size).flatten(0, 1)
            self._img_pos_size = grid_size
        return self.img_pos

    def forward(self, inputs):
        """Encode the prompts found in ``inputs``.

        Args:
            inputs: Dict with optional ``"boxes"`` and/or ``"points"``
                prompts and an ``"img_embeds"`` tensor. The grid size is read
                from ``img_embeds.shape[2:-1]`` (assumes a layout like
                ``(batch, ?, rows, cols, channels)`` -- TODO confirm).

        Returns:
            Dict with ``"sparse_embeds"`` (box embeddings followed by point
            embeddings along dim 1) and ``"img_pos"`` (flattened grid
            positional embedding).

        Raises:
            ValueError: If neither ``points`` nor ``boxes`` is provided.
        """
        sparse_embeds = []
        if inputs.get("boxes", None) is not None:
            coords, labels = self.to_points(boxes=inputs["boxes"])
            sparse_embeds.append(self.encode_points(coords, labels))
        if inputs.get("points", None) is not None:
            coords, labels = self.to_points(points=inputs["points"])
            sparse_embeds.append(self.encode_points(coords, labels))
        if not sparse_embeds:
            raise ValueError("Expected ``points`` or ``boxes`` prompts.")
        if len(sparse_embeds) > 1:
            sparse_embeds = [torch.cat(sparse_embeds, dim=1)]
        img_pos = self._get_img_pos(inputs["img_embeds"].shape[2:-1])
        return {"sparse_embeds": sparse_embeds[0], "img_pos": img_pos}