# SG3D-Demo: leo/pcd_encoder.py
# Last commit 9de012e by zfzhang-thu (non-LFS, 16.7 kB)
import torch
import einops
import numpy as np
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Optional
from leo.utils import get_activation_fn, layer_repeat, calc_pairwise_locs
def disabled_train(self, mode=True):
    """Stand-in for ``model.train`` that leaves the module's mode alone.

    Assign this over a module's ``train`` method so later ``train()`` /
    ``eval()`` calls can no longer flip it between train and eval mode.

    Args:
        self: the module (or any object) the call is bound to.
        mode: accepted for signature compatibility and deliberately ignored.

    Returns:
        ``self``, unchanged.
    """
    del mode  # intentionally unused
    return self
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention followed by a feed-forward net.

    ``prenorm=False`` (default) is the post-norm arrangement
    ``x = LN(x + SubLayer(x))``; ``prenorm=True`` is the pre-norm arrangement
    ``x = x + SubLayer(LN(x))`` for both sub-layers. ``forward`` also returns
    the attention weights produced by ``nn.MultiheadAttention``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, batch_first=True, dropout=0.1, activation="relu", prenorm=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Implementation of Feedforward modules
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = get_activation_fn(activation)
        self.prenorm = prenorm

    def forward(
        self, tgt, tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        """Apply the layer.

        Args:
            tgt: input sequence (layout follows ``batch_first`` of ``self_attn``).
            tgt_mask: optional ``attn_mask`` forwarded to ``nn.MultiheadAttention``.
            tgt_key_padding_mask: optional ``key_padding_mask`` forwarded to
                ``nn.MultiheadAttention`` (True marks positions to ignore).

        Returns:
            ``(output, attention_weights)``.
        """
        # --- self-attention sub-layer ---
        tgt2 = self.norm1(tgt) if self.prenorm else tgt
        tgt2, self_attn_matrices = self.self_attn(
            query=tgt2, key=tgt2, value=tgt2, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)
        if not self.prenorm:
            tgt = self.norm1(tgt)

        # --- feed-forward sub-layer ---
        # In pre-norm mode only the sub-layer input is normalized; the residual
        # stream `tgt` itself must stay un-normalized (x + FF(LN(x))).
        tgt2 = self.norm2(tgt) if self.prenorm else tgt
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        if not self.prenorm:
            tgt = self.norm2(tgt)
        return tgt, self_attn_matrices
class MultiHeadAttentionSpatial(nn.Module):
    """Multi-head attention whose logits are fused with pairwise spatial features.

    Shapes (fixed by the projections and reshapes in ``forward``):
        q: (B, L, d_model); k, v: (B, T, d_model)
        pairwise_locs: (B, L, T, spatial_dim)
        key_padding_mask: optional (B, T) bool tensor, True = key is ignored
    ``forward`` returns ``(output, fused_attn)`` with output (B, L, d_model)
    and fused_attn (n_head, B, L, T), the final attention weights.

    ``spatial_attn_fusion`` selects how the spatial term ``loc`` is combined
    with the content logits ``a = q.k / sqrt(d_head)``:
        'mul'  : loc = relu(Linear(pairwise_locs)); softmax(a + log(max(loc, 1e-6)))
        'cond' : loc = sigmoid(w . pairwise_locs + b), with (b, w) predicted per
                 query from the un-projected ``q``; fused like 'mul'
        'bias' : loc = Linear(pairwise_locs); softmax(a + loc)
        'ctx'  : loc = q . Linear(pairwise_locs) / sqrt(d_head) per head; softmax(a + loc)
        'add'  : (softmax(a) + softmax(loc)) / 2 with loc as in 'bias'
    With ``spatial_multihead=False`` one spatial term is shared by all heads
    (ignored by 'ctx'). The attention output goes through ``fc``, dropout, a
    residual connection to the un-projected ``q`` and ``layer_norm``.
    """

    def __init__(
        self, d_model, n_head, dropout=0.1, spatial_multihead=True, spatial_dim=5,
        spatial_attn_fusion='mul',
    ):
        super().__init__()
        assert d_model % n_head == 0, 'd_model: %d, n_head: %d' % (d_model, n_head)
        self.n_head = n_head
        self.d_model = d_model
        self.d_per_head = d_model // n_head
        self.spatial_multihead = spatial_multihead
        self.spatial_dim = spatial_dim
        self.spatial_attn_fusion = spatial_attn_fusion
        self.w_qs = nn.Linear(d_model, d_model)
        self.w_ks = nn.Linear(d_model, d_model)
        self.w_vs = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.spatial_n_head = n_head if spatial_multihead else 1
        if self.spatial_attn_fusion in ['mul', 'bias', 'add']:
            self.pairwise_loc_fc = nn.Linear(spatial_dim, self.spatial_n_head)
        elif self.spatial_attn_fusion == 'ctx':
            self.pairwise_loc_fc = nn.Linear(spatial_dim, d_model)
        elif self.spatial_attn_fusion == 'cond':
            self.lang_cond_fc = nn.Linear(d_model, self.spatial_n_head * (spatial_dim + 1))
        else:
            raise NotImplementedError('unsupported spatial_attn_fusion %s' % (self.spatial_attn_fusion))

    def _split_heads(self, x):
        """Reshape (B, S, n_head * d) -> (n_head, B, S, d); the head index is the outer factor."""
        batch, seq, _ = x.shape
        return x.reshape(batch, seq, self.n_head, -1).permute(2, 0, 1, 3)

    def forward(self, q, k, v, pairwise_locs, key_padding_mask=None, txt_embeds=None):
        """Attend from ``q`` to ``k``/``v`` with spatially fused weights (see class docstring).

        ``txt_embeds`` is accepted for interface compatibility and is unused.
        Raises AssertionError if the fused attention weights contain NaN.
        """
        residual = q
        q = self._split_heads(self.w_qs(q))
        k = self._split_heads(self.w_ks(k))
        v = self._split_heads(self.w_vs(v))
        n_head, _, len_q, d_head = q.shape
        scale = np.sqrt(d_head)

        attn = torch.einsum('hblk,hbtk->hblt', q, k) / scale

        fusion = self.spatial_attn_fusion
        if fusion in ['mul', 'bias', 'add']:
            # (B, L, T, spatial_n_head) -> (spatial_n_head, B, L, T)
            loc_attn = self.pairwise_loc_fc(pairwise_locs).permute(3, 0, 1, 2)
            if fusion == 'mul':
                loc_attn = F.relu(loc_attn)
            if not self.spatial_multihead:
                # Broadcast the single spatial head to every attention head.
                loc_attn = loc_attn.expand(n_head, -1, -1, -1)
        elif fusion == 'ctx':
            loc_ctx = self.pairwise_loc_fc(pairwise_locs)
            b, l, t, _ = loc_ctx.shape
            # (B, L, T, n_head * d) -> (n_head, B, L, T, d)
            loc_ctx = loc_ctx.reshape(b, l, t, n_head, -1).permute(3, 0, 1, 2, 4)
            loc_attn = torch.einsum('hblk,hbltk->hblt', q, loc_ctx) / scale
        elif fusion == 'cond':
            # Per-query (bias, weights) over the spatial features, predicted from the raw query input.
            cond = self.lang_cond_fc(residual)
            b, l, _ = cond.shape
            cond = cond.reshape(b, l, self.spatial_n_head, self.spatial_dim + 1).permute(2, 0, 1, 3)
            if self.spatial_n_head == 1:
                cond = cond.expand(n_head, -1, -1, -1)
            spatial_bias = cond[..., :1]
            spatial_weights = cond[..., 1:]
            loc_attn = torch.einsum('hbld,bltd->hblt', spatial_weights, pairwise_locs) + spatial_bias
            loc_attn = torch.sigmoid(loc_attn)

        if key_padding_mask is not None:
            # (B, T) -> (n_head, B, L, T)
            mask = key_padding_mask[None, :, None, :].expand(n_head, -1, len_q, -1)
            attn = attn.masked_fill(mask, -np.inf)
            # Multiplicative spatial terms are zeroed; additive ones are pushed to -inf.
            if fusion in ['mul', 'cond']:
                loc_attn = loc_attn.masked_fill(mask, 0)
            else:
                loc_attn = loc_attn.masked_fill(mask, -np.inf)

        if fusion == 'add':
            fused_attn = (torch.softmax(attn, 3) + torch.softmax(loc_attn, 3)) / 2
        else:
            if fusion in ['mul', 'cond']:
                logits = torch.log(torch.clamp(loc_attn, min=1e-6)) + attn
            else:
                logits = loc_attn + attn
            fused_attn = torch.softmax(logits, 3)
        # Fail on ANY NaN weight, e.g. a query whose keys are all masked
        # (softmax over a row of -inf).
        assert not torch.isnan(fused_attn).any(), 'NaN in fused spatial attention weights:\n%s' % (fused_attn,)

        output = torch.einsum('hblt,hbtv->hblv', fused_attn, v)
        # (n_head, B, L, d) -> (B, L, n_head * d)
        output = output.permute(1, 2, 0, 3).flatten(start_dim=2)
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, fused_attn
class TransformerSpatialEncoderLayer(TransformerEncoderLayer):
    """Post-norm encoder layer whose self-attention is ``MultiHeadAttentionSpatial``.

    The feed-forward and normalization sub-modules come from the parent
    ``TransformerEncoderLayer``. Its ``nn.MultiheadAttention`` is discarded and
    replaced by the spatial attention module. ``forward`` ignores ``prenorm``
    and always uses the post-norm arrangement.
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
        spatial_multihead=True, spatial_dim=5, spatial_attn_fusion='mul'
    ):
        super().__init__(
            d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation
        )
        # Remove the parent's attention module before registering the spatial one.
        del self.self_attn
        self.self_attn = MultiHeadAttentionSpatial(
            d_model,
            nhead,
            dropout=dropout,
            spatial_multihead=spatial_multihead,
            spatial_dim=spatial_dim,
            spatial_attn_fusion=spatial_attn_fusion,
        )

    def forward(
        self, tgt, tgt_pairwise_locs,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        """Run spatial self-attention and the feed-forward block.

        Args:
            tgt: token features, passed as query, key and value.
            tgt_pairwise_locs: pairwise spatial features for ``self_attn``.
            tgt_mask: accepted for signature compatibility; not used.
            tgt_key_padding_mask: forwarded as ``key_padding_mask``.

        Returns:
            ``(output, attention_weights)``.
        """
        attn_out, attn_weights = self.self_attn(
            tgt, tgt, tgt, tgt_pairwise_locs,
            key_padding_mask=tgt_key_padding_mask,
        )
        # self_attn already adds its own query residual and LayerNorm internally,
        # so `tgt` enters this residual sum a second time.
        hidden = self.norm1(tgt + self.dropout1(attn_out))
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(hidden))))
        hidden = self.norm2(hidden + self.dropout2(ffn_out))
        return hidden, attn_weights
def _init_weights_bert(module, std=0.02):
"""
Huggingface transformer weight initialization,
most commonly for bert initialization
"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def generate_fourier_features(pos, num_bands=10, max_freq=15, concat_pos=True, sine_only=False):
    """Encode ``pos`` with sinusoidal Fourier features.

    Each of the C input channels is multiplied by ``num_bands`` frequencies
    spaced linearly in [1, max_freq]. The products ``x`` are mapped through
    sin(pi * x), and also cos(pi * x) unless ``sine_only``.

    Args:
        pos: tensor of shape (..., C), e.g. (B, N, C); any number of leading
            dimensions is accepted.
        num_bands: number of frequencies per channel.
        max_freq: highest frequency (the lowest is 1).
        concat_pos: prepend the raw ``pos`` values to the output.
        sine_only: emit only the sine features.

    Returns:
        Tensor of shape (..., C') with
        C' = C * num_bands * (1 if sine_only else 2) + (C if concat_pos else 0),
        laid out as [pos | sin | cos]. Within the sin/cos blocks, features are
        channel-major: all bands of channel 0, then channel 1, ...
    """
    freq_bands = torch.linspace(start=1.0, end=max_freq, steps=num_bands, device=pos.device)
    # (..., C, 1) * (num_bands,) -> (..., C, num_bands) -> (..., C * num_bands)
    scaled = (pos.unsqueeze(-1) * freq_bands).flatten(start_dim=-2)
    if sine_only:
        features = torch.sin(np.pi * scaled)
    else:
        features = torch.cat([torch.sin(np.pi * scaled), torch.cos(np.pi * scaled)], dim=-1)
    if concat_pos:
        features = torch.cat([pos, features], dim=-1)
    return features
class OSE3D(nn.Module):
    """Open-vocabulary, Spatial-attention, Embodied-token, 3D-agent object encoder.

    Projects per-object features (768-d) to ``hidden_dim``. It can optionally
    prepend a learned "embodied" anchor token, then adds a type embedding
    (0 = object, 1 = anchor) and a learned location encoding. Finally it runs a
    stack of encoder layers: ``TransformerSpatialEncoderLayer`` when
    ``use_spatial_attn`` is set, otherwise ``TransformerEncoderLayer``.

    Args:
        use_spatial_attn: use pairwise-location-aware spatial attention layers.
        use_embodied_token: prepend an anchor token built from
            ``data_dict['anchor_locs']`` and ``data_dict['anchor_orientation']``.
        hidden_dim: token width.
        fourier_size: input width of ``orient_encoder``. It must equal the width
            of ``generate_fourier_features(anchor_orientation)`` with default
            settings, i.e. 21 * C for a C-channel orientation (84 for C = 4).
        spatial_encoder: encoder config mapping. Keys missing from it fall back
            to ``_DEFAULT_SPATIAL_ENCODER``; ``None`` means all defaults.
    """

    # Defaults for the `spatial_encoder` config; only read and copied, never mutated.
    _DEFAULT_SPATIAL_ENCODER = {
        "num_attention_heads": 8,
        "dim_feedforward": 2048,
        "dropout": 0.1,
        "activation": "gelu",
        "spatial_dim": 5,
        "spatial_multihead": True,
        "spatial_attn_fusion": "cond",
        "num_layers": 3,
        "pairwise_rel_type": "center",
        "spatial_dist_norm": True,
        "obj_loc_encoding": "same_all",
        "dim_loc": 6,
    }

    def __init__(self, use_spatial_attn=True, use_embodied_token=False, hidden_dim=256, fourier_size=84, spatial_encoder=None):
        super().__init__()
        cfg = dict(self._DEFAULT_SPATIAL_ENCODER)
        if spatial_encoder is not None:
            cfg.update(spatial_encoder)

        self.use_spatial_attn = use_spatial_attn  # spatial attention
        self.use_embodied_token = use_embodied_token  # embodied token

        # Object features are read ready-made from data_dict['obj_feats']; only project them.
        self.obj_proj = nn.Linear(768, hidden_dim)

        # embodied token
        if self.use_embodied_token:
            self.anchor_feat = nn.Parameter(torch.zeros(1, 1, hidden_dim))
            self.anchor_size = nn.Parameter(torch.ones(1, 1, 3))
            self.orient_encoder = nn.Linear(fourier_size, hidden_dim)
        self.obj_type_embed = nn.Embedding(2, hidden_dim)

        # spatial encoder
        if self.use_spatial_attn:
            spatial_encoder_layer = TransformerSpatialEncoderLayer(
                d_model=hidden_dim,
                nhead=cfg['num_attention_heads'],
                dim_feedforward=cfg['dim_feedforward'],
                dropout=cfg['dropout'],
                activation=cfg['activation'],
                spatial_dim=cfg['spatial_dim'],
                spatial_multihead=cfg['spatial_multihead'],
                spatial_attn_fusion=cfg['spatial_attn_fusion'],
            )
        else:
            spatial_encoder_layer = TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=cfg['num_attention_heads'],
                dim_feedforward=cfg['dim_feedforward'],
                dropout=cfg['dropout'],
                activation=cfg['activation'],
            )
        self.spatial_encoder = layer_repeat(
            spatial_encoder_layer,
            cfg['num_layers'],
        )
        self.pairwise_rel_type = cfg['pairwise_rel_type']
        self.spatial_dist_norm = cfg['spatial_dist_norm']
        self.spatial_dim = cfg['spatial_dim']
        self.obj_loc_encoding = cfg['obj_loc_encoding']

        # location encoding: one shared layer ('same_0', 'same_all') or one per encoder layer ('diff_all')
        if self.obj_loc_encoding in ['same_0', 'same_all']:
            num_loc_layers = 1
        elif self.obj_loc_encoding == 'diff_all':
            num_loc_layers = cfg['num_layers']
        else:
            raise NotImplementedError('unsupported obj_loc_encoding %s' % (self.obj_loc_encoding))
        loc_layer = nn.Sequential(
            nn.Linear(cfg['dim_loc'], hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        self.loc_layers = layer_repeat(loc_layer, num_loc_layers)

        # only initialize spatial encoder and loc layers
        self.spatial_encoder.apply(_init_weights_bert)
        self.loc_layers.apply(_init_weights_bert)
        if self.use_embodied_token:
            nn.init.normal_(self.anchor_feat, std=0.02)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(self, data_dict):
        """Encode objects (and optionally the embodied anchor) into tokens.

        data_dict requires keys:
            obj_feats: (B, N, 768) per-object features
            obj_masks: (B, N), 1/True valid and 0/False masked
            obj_locs: (B, N, 6), xyz + whd
            anchor_locs: (B, 3)             -- only with use_embodied_token
            anchor_orientation: (B, C)      -- only with use_embodied_token

        ``data_dict`` is updated in place and returned, with:
            obj_tokens: (B, N', hidden_dim) encoded tokens
            obj_masks: (B, N') bool, True = valid. N' = N + 1 with the anchor
                token first when use_embodied_token, else N.
        """
        obj_feats = self.obj_proj(data_dict['obj_feats'])
        # Encoder layers take a key-padding mask (True = ignore): the inverse of obj_masks.
        # `.bool()` makes integer 1/0 masks invert correctly too.
        obj_masks = ~data_dict['obj_masks'].bool()
        B, N = obj_feats.shape[:2]
        device = obj_feats.device
        obj_type_ids = torch.zeros((B, N), dtype=torch.long, device=device)
        obj_type_embeds = self.obj_type_embed(obj_type_ids)

        if self.use_embodied_token:
            # anchor feature: learned embedding + encoded orientation
            anchor_orient = data_dict['anchor_orientation'].unsqueeze(1)
            anchor_orient_feat = self.orient_encoder(generate_fourier_features(anchor_orient))
            anchor_feat = self.anchor_feat + anchor_orient_feat
            anchor_mask = torch.zeros((B, 1), dtype=torch.bool, device=device)  # anchor is never padded
            # anchor loc (3) + learned size (3)
            anchor_loc = torch.cat(
                [data_dict['anchor_locs'].unsqueeze(1), self.anchor_size.expand(B, -1, -1).to(device)], dim=-1
            )
            anchor_type_id = torch.ones((B, 1), dtype=torch.long, device=device)
            anchor_type_embed = self.obj_type_embed(anchor_type_id)
            # fuse anchor and objs (anchor first)
            all_obj_feats = torch.cat([anchor_feat, obj_feats], dim=1)
            all_obj_masks = torch.cat([anchor_mask, obj_masks], dim=1)
            all_obj_locs = torch.cat([anchor_loc, data_dict['obj_locs']], dim=1)
            all_obj_type_embeds = torch.cat([anchor_type_embed, obj_type_embeds], dim=1)
        else:
            all_obj_feats = obj_feats
            all_obj_masks = obj_masks
            all_obj_locs = data_dict['obj_locs']
            all_obj_type_embeds = obj_type_embeds
        all_obj_feats = all_obj_feats + all_obj_type_embeds

        # call spatial encoder
        if self.use_spatial_attn:
            pairwise_locs = calc_pairwise_locs(
                all_obj_locs[:, :, :3],
                all_obj_locs[:, :, 3:],
                pairwise_rel_type=self.pairwise_rel_type,
                spatial_dist_norm=self.spatial_dist_norm,
                spatial_dim=self.spatial_dim,
            )
        for i, pc_layer in enumerate(self.spatial_encoder):
            # 'same_0' adds the location encoding before the first layer only.
            if self.obj_loc_encoding != 'same_0' or i == 0:
                loc_layer = self.loc_layers[i] if self.obj_loc_encoding == 'diff_all' else self.loc_layers[0]
                all_obj_feats = all_obj_feats + loc_layer(all_obj_locs)
            if self.use_spatial_attn:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats, pairwise_locs,
                    tgt_key_padding_mask=all_obj_masks
                )
            else:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats,
                    tgt_key_padding_mask=all_obj_masks
                )

        data_dict['obj_tokens'] = all_obj_feats
        data_dict['obj_masks'] = ~all_obj_masks
        return data_dict