"""Spatial-attention transformer layers and the OSE3D object-token encoder."""
from typing import Optional

import einops
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from leo.utils import calc_pairwise_locs, get_activation_fn, layer_repeat


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class TransformerEncoderLayer(nn.Module):
    """Standard self-attention + feed-forward encoder layer.

    Supports post-norm (default) and pre-norm (`prenorm=True`) placement of
    the two LayerNorms.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, batch_first=True, dropout=0.1,
                 activation="relu", prenorm=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Implementation of Feedforward modules
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = get_activation_fn(activation)
        self.prenorm = prenorm

    def forward(
        self, tgt, tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        """Run one encoder layer.

        Returns:
            (tgt, self_attn_matrices): the updated sequence and the attention
            weights returned by `nn.MultiheadAttention`.
        """
        # Self-attention sub-layer.
        tgt2 = self.norm1(tgt) if self.prenorm else tgt
        tgt2, self_attn_matrices = self.self_attn(
            query=tgt2, key=tgt2, value=tgt2, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)
        if not self.prenorm:
            tgt = self.norm1(tgt)

        # Feed-forward sub-layer. In pre-norm mode only the FFN *input* is
        # normalised; the residual stream `tgt` must stay un-normalised.
        tgt2 = self.norm2(tgt) if self.prenorm else tgt
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        if not self.prenorm:
            tgt = self.norm2(tgt)
        return tgt, self_attn_matrices


class MultiHeadAttentionSpatial(nn.Module):
    """Multi-head attention whose scores are fused with pairwise-location scores.

    `spatial_attn_fusion` selects how location information enters:
      - 'mul'  : relu(linear(locs)) multiplies the attention (added in log space)
      - 'bias' : linear(locs) is added to the logits
      - 'add'  : softmax(logits) and softmax(linear(locs)) are averaged
      - 'ctx'  : queries attend to a projection of the pairwise locations
      - 'cond' : per-query weights (from the query input) gate the locations
    """

    def __init__(
        self, d_model, n_head, dropout=0.1, spatial_multihead=True, spatial_dim=5,
        spatial_attn_fusion='mul',
    ):
        super().__init__()
        assert d_model % n_head == 0, 'd_model: %d, n_head: %d' % (d_model, n_head)

        self.n_head = n_head
        self.d_model = d_model
        self.d_per_head = d_model // n_head
        self.spatial_multihead = spatial_multihead
        self.spatial_dim = spatial_dim
        self.spatial_attn_fusion = spatial_attn_fusion

        self.w_qs = nn.Linear(d_model, d_model)
        self.w_ks = nn.Linear(d_model, d_model)
        self.w_vs = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

        # A single spatial head is broadcast to all attention heads in forward().
        self.spatial_n_head = n_head if spatial_multihead else 1
        if self.spatial_attn_fusion in ['mul', 'bias', 'add']:
            self.pairwise_loc_fc = nn.Linear(spatial_dim, self.spatial_n_head)
        elif self.spatial_attn_fusion == 'ctx':
            self.pairwise_loc_fc = nn.Linear(spatial_dim, d_model)
        elif self.spatial_attn_fusion == 'cond':
            # One bias + `spatial_dim` weights per spatial head.
            self.lang_cond_fc = nn.Linear(d_model, self.spatial_n_head * (spatial_dim + 1))
        else:
            raise NotImplementedError('unsupported spatial_attn_fusion %s' % (self.spatial_attn_fusion))

    def forward(self, q, k, v, pairwise_locs, key_padding_mask=None, txt_embeds=None):
        """Attend from `q` to `k`/`v`, fusing in pairwise location scores.

        Args:
            q: (b, l, d_model) queries; also used as the residual.
            k, v: (b, t, d_model) keys and values.
            pairwise_locs: (b, l, t, spatial_dim) pairwise location features.
            key_padding_mask: optional (b, t) bool mask, True = ignore this key.
            txt_embeds: unused.

        Returns:
            (output, fused_attn): output is (b, l, d_model) after residual +
            LayerNorm; fused_attn is the (head, b, l, t) attention distribution.
        """
        residual = q
        q = einops.rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=self.n_head)
        k = einops.rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=self.n_head)
        v = einops.rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=self.n_head)
        attn = torch.einsum('hblk,hbtk->hblt', q, k) / np.sqrt(q.shape[-1])

        if self.spatial_attn_fusion in ['mul', 'bias', 'add']:
            loc_attn = self.pairwise_loc_fc(pairwise_locs)
            loc_attn = einops.rearrange(loc_attn, 'b l t h -> h b l t')
            if self.spatial_attn_fusion == 'mul':
                loc_attn = F.relu(loc_attn)
            if not self.spatial_multihead:
                loc_attn = einops.repeat(loc_attn, 'h b l t -> (h nh) b l t', nh=self.n_head)
        elif self.spatial_attn_fusion == 'ctx':
            loc_attn = self.pairwise_loc_fc(pairwise_locs)
            loc_attn = einops.rearrange(loc_attn, 'b l t (h k) -> h b l t k', h=self.n_head)
            loc_attn = torch.einsum('hblk,hbltk->hblt', q, loc_attn) / np.sqrt(q.shape[-1])
        elif self.spatial_attn_fusion == 'cond':
            spatial_weights = self.lang_cond_fc(residual)
            spatial_weights = einops.rearrange(
                spatial_weights, 'b l (h d) -> h b l d',
                h=self.spatial_n_head, d=self.spatial_dim + 1
            )
            if self.spatial_n_head == 1:
                spatial_weights = einops.repeat(spatial_weights, '1 b l d -> h b l d', h=self.n_head)
            spatial_bias = spatial_weights[..., :1]
            spatial_weights = spatial_weights[..., 1:]
            loc_attn = torch.einsum('hbld,bltd->hblt', spatial_weights, pairwise_locs) + spatial_bias
            loc_attn = torch.sigmoid(loc_attn)

        if key_padding_mask is not None:
            mask = einops.repeat(key_padding_mask, 'b t -> h b l t', h=self.n_head, l=q.size(2))
            attn = attn.masked_fill(mask, -np.inf)
            # 'mul'/'cond' scores are non-negative multipliers (0 = blocked);
            # the others are logits (-inf = blocked).
            if self.spatial_attn_fusion in ['mul', 'cond']:
                loc_attn = loc_attn.masked_fill(mask, 0)
            else:
                loc_attn = loc_attn.masked_fill(mask, -np.inf)

        if self.spatial_attn_fusion == 'add':
            fused_attn = (torch.softmax(attn, 3) + torch.softmax(loc_attn, 3)) / 2
        else:
            if self.spatial_attn_fusion in ['mul', 'cond']:
                # Multiplying probabilities == adding logs; clamp avoids log(0).
                fused_attn = torch.log(torch.clamp(loc_attn, min=1e-6)) + attn
            else:
                fused_attn = loc_attn + attn
            fused_attn = torch.softmax(fused_attn, 3)

        # NaNs appear e.g. when every key of a sample is masked (softmax over
        # all -inf); fail loudly instead of propagating them.
        assert not torch.isnan(fused_attn).any(), 'NaN in fused attention weights'

        output = torch.einsum('hblt,hbtv->hblv', fused_attn, v)
        output = einops.rearrange(output, 'head b l v -> b l (head v)')
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, fused_attn


class TransformerSpatialEncoderLayer(TransformerEncoderLayer):
    """Encoder layer whose self-attention is `MultiHeadAttentionSpatial`.

    Always uses post-norm placement; `tgt_mask` is accepted but not used.
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
        spatial_multihead=True, spatial_dim=5, spatial_attn_fusion='mul'
    ):
        super().__init__(
            d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation
        )
        del self.self_attn
        self.self_attn = MultiHeadAttentionSpatial(
            d_model, nhead, dropout=dropout,
            spatial_multihead=spatial_multihead,
            spatial_dim=spatial_dim,
            spatial_attn_fusion=spatial_attn_fusion,
        )

    def forward(
        self, tgt, tgt_pairwise_locs,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        """Run one spatial encoder layer; returns (tgt, attention weights)."""
        # NOTE(review): MultiHeadAttentionSpatial already applies its own
        # residual + LayerNorm, so the residual below is a second one.
        tgt2, self_attn_matrices = self.self_attn(
            tgt, tgt, tgt, tgt_pairwise_locs,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        return tgt, self_attn_matrices


def _init_weights_bert(module, std=0.02):
    """Huggingface transformer weight initialization, most commonly for BERT."""
    if isinstance(module, nn.Linear):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def generate_fourier_features(pos, num_bands=10, max_freq=15, concat_pos=True, sine_only=False):
    """Encode positions with sine/cosine features at `num_bands` frequencies.

    Args:
        pos: (B, N, C) tensor.
        num_bands: number of frequencies, linearly spaced in [1, max_freq].
        max_freq: highest frequency.
        concat_pos: prepend the raw `pos` to the features.
        sine_only: emit only the sine features.

    Returns:
        (B, N, C') tensor with C' = C * num_bands * (1 if sine_only else 2),
        plus C when `concat_pos` is set.
    """
    min_freq = 1.0
    freq_bands = torch.linspace(start=min_freq, end=max_freq, steps=num_bands, device=pos.device)

    # (B, N, C, num_bands) -> (B, N, C * num_bands)
    per_pos_features = (pos.unsqueeze(-1) * freq_bands).flatten(start_dim=2)

    if sine_only:
        per_pos_features = torch.sin(np.pi * per_pos_features)
    else:
        per_pos_features = torch.cat(
            [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1
        )

    if concat_pos:
        per_pos_features = torch.cat([pos, per_pos_features], dim=-1)
    return per_pos_features


# Default spatial-encoder configuration for OSE3D; caller-supplied entries
# override these key by key.
_DEFAULT_SPATIAL_ENCODER = {
    "num_attention_heads": 8,
    "dim_feedforward": 2048,
    "dropout": 0.1,
    "activation": "gelu",
    "spatial_dim": 5,
    "spatial_multihead": True,
    "spatial_attn_fusion": "cond",
    "num_layers": 3,
    "pairwise_rel_type": "center",
    "spatial_dist_norm": True,
    "obj_loc_encoding": "same_all",
    "dim_loc": 6,
}


class OSE3D(nn.Module):
    """Object-centric scene encoder.

    Open-vocabulary, Spatial-attention, Embodied-token, 3D-agent.
    """

    def __init__(self, use_spatial_attn=True, use_embodied_token=False, hidden_dim=256,
                 fourier_size=84, spatial_encoder=None):
        """
        Args:
            use_spatial_attn: use TransformerSpatialEncoderLayer instead of a
                plain TransformerEncoderLayer.
            use_embodied_token: prepend a learned anchor ("agent") token.
            hidden_dim: token width.
            fourier_size: input width of the anchor-orientation encoder; must
                match the width produced by `generate_fourier_features`.
            spatial_encoder: mapping of encoder settings; missing keys fall
                back to `_DEFAULT_SPATIAL_ENCODER`.
        """
        super().__init__()
        spatial_encoder = {**_DEFAULT_SPATIAL_ENCODER, **(spatial_encoder or {})}
        self.use_spatial_attn = use_spatial_attn
        self.use_embodied_token = use_embodied_token

        self.obj_proj = nn.Linear(768, hidden_dim)

        # embodied token
        if self.use_embodied_token:
            self.anchor_feat = nn.Parameter(torch.zeros(1, 1, hidden_dim))
            self.anchor_size = nn.Parameter(torch.ones(1, 1, 3))
            self.orient_encoder = nn.Linear(fourier_size, hidden_dim)
        # Type 0 = object, type 1 = anchor. forward() embeds the object type
        # ids unconditionally, so this must exist without the embodied token too.
        self.obj_type_embed = nn.Embedding(2, hidden_dim)

        # spatial encoder
        if self.use_spatial_attn:
            spatial_encoder_layer = TransformerSpatialEncoderLayer(
                d_model=hidden_dim,
                nhead=spatial_encoder['num_attention_heads'],
                dim_feedforward=spatial_encoder['dim_feedforward'],
                dropout=spatial_encoder['dropout'],
                activation=spatial_encoder['activation'],
                spatial_dim=spatial_encoder['spatial_dim'],
                spatial_multihead=spatial_encoder['spatial_multihead'],
                spatial_attn_fusion=spatial_encoder['spatial_attn_fusion'],
            )
        else:
            spatial_encoder_layer = TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=spatial_encoder['num_attention_heads'],
                dim_feedforward=spatial_encoder['dim_feedforward'],
                dropout=spatial_encoder['dropout'],
                activation=spatial_encoder['activation'],
            )
        self.spatial_encoder = layer_repeat(
            spatial_encoder_layer,
            spatial_encoder['num_layers'],
        )
        self.pairwise_rel_type = spatial_encoder['pairwise_rel_type']
        self.spatial_dist_norm = spatial_encoder['spatial_dist_norm']
        self.spatial_dim = spatial_encoder['spatial_dim']
        self.obj_loc_encoding = spatial_encoder['obj_loc_encoding']

        # location encoding: 'same_0' adds it before the first layer only,
        # 'same_all' before every layer, 'diff_all' uses one module per layer.
        if self.obj_loc_encoding in ['same_0', 'same_all']:
            num_loc_layers = 1
        elif self.obj_loc_encoding == 'diff_all':
            num_loc_layers = spatial_encoder['num_layers']
        else:
            raise NotImplementedError('unsupported obj_loc_encoding %s' % (self.obj_loc_encoding))

        loc_layer = nn.Sequential(
            nn.Linear(spatial_encoder['dim_loc'], hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        self.loc_layers = layer_repeat(loc_layer, num_loc_layers)

        # only initialize spatial encoder and loc layers
        self.spatial_encoder.apply(_init_weights_bert)
        self.loc_layers.apply(_init_weights_bert)

        if self.use_embodied_token:
            nn.init.normal_(self.anchor_feat, std=0.02)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(self, data_dict):
        """Encode per-object features into object tokens.

        data_dict requires keys:
            obj_feats: (B, N, 768) precomputed per-object features
            obj_masks: (B, N), 1/True valid and 0/False padded
            obj_locs: (B, N, 6), xyz + whd
            anchor_locs: (B, 3)             -- only with use_embodied_token
            anchor_orientation: (B, C)      -- only with use_embodied_token;
                C must satisfy C * (1 + 2 * 10) == fourier_size (C = 4 by default)

        Writes into data_dict (and returns it):
            obj_tokens: (B, N', hidden_dim), N' = N + 1 with the embodied token
            obj_masks: (B, N') bool, True = valid
        """
        obj_feats = self.obj_proj(data_dict['obj_feats'])
        # Flip to the "True = padded" convention of key_padding_mask. `.bool()`
        # makes `~` a logical (not bitwise) negation for 0/1 integer masks.
        obj_masks = ~data_dict['obj_masks'].bool()

        B, N = obj_feats.shape[:2]
        device = obj_feats.device

        obj_type_ids = torch.zeros((B, N), dtype=torch.long, device=device)
        obj_type_embeds = self.obj_type_embed(obj_type_ids)

        if self.use_embodied_token:
            # anchor feature
            anchor_orient = data_dict['anchor_orientation'].unsqueeze(1)
            anchor_orient_feat = self.orient_encoder(generate_fourier_features(anchor_orient))
            anchor_feat = self.anchor_feat + anchor_orient_feat
            anchor_mask = torch.zeros((B, 1), dtype=torch.bool, device=device)

            # anchor loc (3) + size (3)
            anchor_loc = torch.cat(
                [data_dict['anchor_locs'].unsqueeze(1), self.anchor_size.expand(B, -1, -1).to(device)],
                dim=-1
            )

            # anchor type
            anchor_type_id = torch.ones((B, 1), dtype=torch.long, device=device)
            anchor_type_embed = self.obj_type_embed(anchor_type_id)

            # fuse anchor and objs (anchor goes first)
            all_obj_feats = torch.cat([anchor_feat, obj_feats], dim=1)
            all_obj_masks = torch.cat((anchor_mask, obj_masks), dim=1)
            all_obj_locs = torch.cat([anchor_loc, data_dict['obj_locs']], dim=1)
            all_obj_type_embeds = torch.cat((anchor_type_embed, obj_type_embeds), dim=1)
        else:
            all_obj_feats = obj_feats
            all_obj_masks = obj_masks
            all_obj_locs = data_dict['obj_locs']
            all_obj_type_embeds = obj_type_embeds

        all_obj_feats = all_obj_feats + all_obj_type_embeds

        if self.use_spatial_attn:
            pairwise_locs = calc_pairwise_locs(
                all_obj_locs[:, :, :3],
                all_obj_locs[:, :, 3:],
                pairwise_rel_type=self.pairwise_rel_type,
                spatial_dist_norm=self.spatial_dist_norm,
                spatial_dim=self.spatial_dim,
            )

        # The shared location encoding does not change across layers; compute it once.
        shared_pos = None
        if self.obj_loc_encoding != 'diff_all':
            shared_pos = self.loc_layers[0](all_obj_locs)

        for i, pc_layer in enumerate(self.spatial_encoder):
            if self.obj_loc_encoding == 'diff_all':
                all_obj_feats = all_obj_feats + self.loc_layers[i](all_obj_locs)
            elif self.obj_loc_encoding == 'same_all' or i == 0:
                all_obj_feats = all_obj_feats + shared_pos

            if self.use_spatial_attn:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats, pairwise_locs, tgt_key_padding_mask=all_obj_masks
                )
            else:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats, tgt_key_padding_mask=all_obj_masks
                )

        data_dict['obj_tokens'] = all_obj_feats
        data_dict['obj_masks'] = ~all_obj_masks
        return data_dict