import torch
from torch import nn
from leo.utils import get_mlp_head


class SequentialGroundHead(nn.Module):
    """Scores each object embedding against a grounding query embedding."""

    def __init__(self, hidden_size=4096):
        super().__init__()
        # grounding head: maps the concatenated (object, query) embedding to a scalar logit
        self.og3d_head = get_mlp_head(
            hidden_size * 2, hidden_size // 2,
            1, dropout=0.1
        )

    def forward(self, obj_embeds, grd_embeds, obj_masks=None):
        # obj_embeds: (batch, num_objects, hidden) object embeddings
        # grd_embeds: (batch, 1, hidden) grounding query, broadcast over the objects
        # obj_masks:  (batch, num_objects) boolean mask, True for valid objects
        txt_embeds = grd_embeds.repeat(1, obj_embeds.shape[1], 1)  # (batch, num_objects, hidden)
        og3d_logits = self.og3d_head(
            torch.cat((obj_embeds, txt_embeds), dim=2)
        ).squeeze(2)  # (batch, num_objects)
        if obj_masks is not None:
            # padded objects get -inf so they vanish under a softmax
            og3d_logits = og3d_logits.masked_fill_(obj_masks.logical_not(), -float('inf'))
        return og3d_logits
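

# --- Usage sketch (hypothetical, not part of the original module) ---
# A minimal smoke test assuming the `leo` package providing get_mlp_head is
# importable. Input shapes are inferred from the repeat/cat calls in forward():
# obj_embeds is (batch, num_objects, hidden) and grd_embeds is (batch, 1, hidden).
if __name__ == '__main__':
    batch, num_objs, hidden = 2, 8, 256  # small hidden size for a quick test
    head = SequentialGroundHead(hidden_size=hidden)
    obj_embeds = torch.randn(batch, num_objs, hidden)
    grd_embeds = torch.randn(batch, 1, hidden)
    # mark the last two objects of each sample as padding
    obj_masks = torch.ones(batch, num_objs, dtype=torch.bool)
    obj_masks[:, -2:] = False
    logits = head(obj_embeds, grd_embeds, obj_masks)
    print(logits.shape)              # torch.Size([2, 8])
    print(logits[:, -2:])            # -inf for the masked (padded) objects
    print(logits.softmax(dim=1))     # padded objects receive zero probability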