Spaces:
Runtime error
Runtime error
File size: 709 Bytes
9de012e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from torch import nn
from leo.utils import get_mlp_head
class SequentialGroundHead(nn.Module):
    """Score each object embedding against a grounding (text) embedding.

    The grounding embedding is broadcast across the object axis, concatenated
    with every object embedding along the feature axis, and passed through an
    MLP head built by ``get_mlp_head`` to produce one logit per object.
    """

    def __init__(self, hidden_size: int = 4096, dropout: float = 0.1):
        """Build the grounding head.

        Args:
            hidden_size: Feature size used to size the MLP. The head's input
                width is ``hidden_size * 2`` (object and grounding features
                concatenated) and its hidden width is ``hidden_size // 2``.
            dropout: Dropout rate forwarded to ``get_mlp_head``. Defaults to 0.1.
        """
        super().__init__()
        # Grounding head. The third positional argument (1) is presumably the
        # output width, which is why ``forward`` squeezes dim 2.
        self.og3d_head = get_mlp_head(
            hidden_size * 2, hidden_size // 2,
            1, dropout=dropout
        )

    def forward(self, obj_embeds: torch.Tensor, grd_embdes: torch.Tensor, obj_masks=None) -> torch.Tensor:
        """Compute one grounding logit per object.

        Args:
            obj_embeds: Object embeddings, expected shape ``(B, N, D_obj)``.
            grd_embdes: Grounding embedding, shape ``(B, D_txt)``,
                ``(B, 1, D_txt)``, or already broadcast ``(B, N, D_txt)``.
                ``D_obj + D_txt`` must equal ``hidden_size * 2``. (The
                parameter name's spelling is kept for keyword callers.)
            obj_masks: Optional mask broadcastable to ``(B, N)``. Positions
                where it is falsy receive ``-inf``.

        Returns:
            Logits of shape ``(B, N)``.
        """
        txt_embeds = grd_embdes
        if txt_embeds.dim() == 2:
            # (B, D) -> (B, 1, D) so the embedding can be broadcast per object.
            txt_embeds = txt_embeds.unsqueeze(1)
        num_objs = obj_embeds.shape[1]
        # expand() creates a view (no copy). torch.cat materializes the
        # concatenation once. It also tolerates an already-(B, N, D) input,
        # which repeat() would have blown up to N * N rows.
        txt_embeds = txt_embeds.expand(-1, num_objs, -1)
        fused = torch.cat((obj_embeds, txt_embeds), dim=2)
        og3d_logits = self.og3d_head(fused).squeeze(2)
        if obj_masks is not None:
            # Out-of-place fill: same values as masked_fill_, but it does not
            # mutate the head's output (a view from squeeze) in place.
            # NOTE(review): a row whose objects are all masked becomes all -inf,
            # and a downstream softmax over it would yield NaN.
            og3d_logits = og3d_logits.masked_fill(obj_masks.logical_not(), -float('inf'))
        return og3d_logits