import torch
import einops
import numpy as np
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Optional

from leo.utils import get_activation_fn, layer_repeat, calc_pairwise_locs


def disabled_train(self, mode=True):
    """
    Overwrite model.train with this function to ensure that the train/eval
    mode no longer changes.
    """
    return self
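
# Illustrative usage sketch (not part of the original module): a frozen
# sub-module is typically kept in eval mode by binding this function over
# its .train() method, e.g.
#
#   backbone.eval()
#   for p in backbone.parameters():
#       p.requires_grad = False
#   backbone.train = disabled_train.__get__(backbone)  # bind as a method
#
# so that a later .train() call on the parent model no longer flips the
# frozen backbone back into training mode.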
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, batch_first=True,
                 dropout=0.1, activation="relu", prenorm=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Implementation of the feedforward block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = get_activation_fn(activation)
        self.prenorm = prenorm

    def forward(
        self, tgt, tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        tgt2 = tgt
        if self.prenorm:
            tgt2 = self.norm1(tgt2)
        tgt2, self_attn_matrices = self.self_attn(
            query=tgt2, key=tgt2, value=tgt2, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)
        if not self.prenorm:
            tgt = self.norm1(tgt)
        if self.prenorm:
            tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        if not self.prenorm:
            tgt = self.norm2(tgt)
        return tgt, self_attn_matrices
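
# Illustrative sketch (not part of the original module): a minimal forward pass
# through TransformerEncoderLayer. The shapes and hyper-parameters below are
# example assumptions, not values used elsewhere in this file.
def _example_transformer_encoder_layer():
    layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=1024)
    tgt = torch.randn(2, 10, 256)                   # (batch, seq_len, d_model), batch_first=True
    pad = torch.zeros(2, 10, dtype=torch.bool)      # True marks padded positions
    out, attn = layer(tgt, tgt_key_padding_mask=pad)
    # out: (2, 10, 256); attn: (2, 10, 10), averaged over heads by nn.MultiheadAttention
    return out, attn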
class MultiHeadAttentionSpatial(nn.Module):
    """
    Multi-head attention that fuses content attention with pairwise spatial
    features. `spatial_attn_fusion` selects how the spatial term is computed:
    'mul'/'bias'/'add' project pairwise locations to per-head scores,
    'ctx' projects them to d_model and dots them with the queries, and
    'cond' predicts per-query spatial weights from the query features.
    """

    def __init__(
        self, d_model, n_head, dropout=0.1, spatial_multihead=True, spatial_dim=5,
        spatial_attn_fusion='mul',
    ):
        super().__init__()
        assert d_model % n_head == 0, 'd_model: %d, n_head: %d' % (d_model, n_head)
        self.n_head = n_head
        self.d_model = d_model
        self.d_per_head = d_model // n_head
        self.spatial_multihead = spatial_multihead
        self.spatial_dim = spatial_dim
        self.spatial_attn_fusion = spatial_attn_fusion

        self.w_qs = nn.Linear(d_model, d_model)
        self.w_ks = nn.Linear(d_model, d_model)
        self.w_vs = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

        self.spatial_n_head = n_head if spatial_multihead else 1
        if self.spatial_attn_fusion in ['mul', 'bias', 'add']:
            self.pairwise_loc_fc = nn.Linear(spatial_dim, self.spatial_n_head)
        elif self.spatial_attn_fusion == 'ctx':
            self.pairwise_loc_fc = nn.Linear(spatial_dim, d_model)
        elif self.spatial_attn_fusion == 'cond':
            self.lang_cond_fc = nn.Linear(d_model, self.spatial_n_head * (spatial_dim + 1))
        else:
            raise NotImplementedError('unsupported spatial_attn_fusion %s' % (self.spatial_attn_fusion))

    def forward(self, q, k, v, pairwise_locs, key_padding_mask=None, txt_embeds=None):
        residual = q
        q = einops.rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=self.n_head)
        k = einops.rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=self.n_head)
        v = einops.rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=self.n_head)
        attn = torch.einsum('hblk,hbtk->hblt', q, k) / np.sqrt(q.shape[-1])

        if self.spatial_attn_fusion in ['mul', 'bias', 'add']:
            loc_attn = self.pairwise_loc_fc(pairwise_locs)
            loc_attn = einops.rearrange(loc_attn, 'b l t h -> h b l t')
            if self.spatial_attn_fusion == 'mul':
                loc_attn = F.relu(loc_attn)
            if not self.spatial_multihead:
                loc_attn = einops.repeat(loc_attn, 'h b l t -> (h nh) b l t', nh=self.n_head)
        elif self.spatial_attn_fusion == 'ctx':
            loc_attn = self.pairwise_loc_fc(pairwise_locs)
            loc_attn = einops.rearrange(loc_attn, 'b l t (h k) -> h b l t k', h=self.n_head)
            loc_attn = torch.einsum('hblk,hbltk->hblt', q, loc_attn) / np.sqrt(q.shape[-1])
        elif self.spatial_attn_fusion == 'cond':
            spatial_weights = self.lang_cond_fc(residual)
            spatial_weights = einops.rearrange(
                spatial_weights, 'b l (h d) -> h b l d',
                h=self.spatial_n_head, d=self.spatial_dim + 1
            )
            if self.spatial_n_head == 1:
                spatial_weights = einops.repeat(spatial_weights, '1 b l d -> h b l d', h=self.n_head)
            spatial_bias = spatial_weights[..., :1]
            spatial_weights = spatial_weights[..., 1:]
            loc_attn = torch.einsum('hbld,bltd->hblt', spatial_weights, pairwise_locs) + spatial_bias
            loc_attn = torch.sigmoid(loc_attn)

        if key_padding_mask is not None:
            mask = einops.repeat(key_padding_mask, 'b t -> h b l t', h=self.n_head, l=q.size(2))
            attn = attn.masked_fill(mask, -np.inf)
            if self.spatial_attn_fusion in ['mul', 'cond']:
                loc_attn = loc_attn.masked_fill(mask, 0)
            else:
                loc_attn = loc_attn.masked_fill(mask, -np.inf)

        if self.spatial_attn_fusion == 'add':
            fused_attn = (torch.softmax(attn, 3) + torch.softmax(loc_attn, 3)) / 2
        else:
            if self.spatial_attn_fusion in ['mul', 'cond']:
                fused_attn = torch.log(torch.clamp(loc_attn, min=1e-6)) + attn
            else:
                fused_attn = loc_attn + attn
            fused_attn = torch.softmax(fused_attn, 3)

        assert not torch.isnan(fused_attn).any(), fused_attn

        output = torch.einsum('hblt,hbtv->hblv', fused_attn, v)
        output = einops.rearrange(output, 'head b l v -> b l (head v)')
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, fused_attn
class TransformerSpatialEncoderLayer(TransformerEncoderLayer):
    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
        spatial_multihead=True, spatial_dim=5, spatial_attn_fusion='mul'
    ):
        super().__init__(
            d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation
        )
        del self.self_attn
        self.self_attn = MultiHeadAttentionSpatial(
            d_model, nhead, dropout=dropout,
            spatial_multihead=spatial_multihead,
            spatial_dim=spatial_dim,
            spatial_attn_fusion=spatial_attn_fusion,
        )

    def forward(
        self, tgt, tgt_pairwise_locs,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        tgt2 = tgt
        tgt2, self_attn_matrices = self.self_attn(
            tgt2, tgt2, tgt2, tgt_pairwise_locs,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        return tgt, self_attn_matrices
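
# Illustrative sketch (not part of the original module): a minimal forward pass
# through TransformerSpatialEncoderLayer. In the real pipeline pairwise_locs
# comes from calc_pairwise_locs; here it is faked with the assumed shape
# (B, L, L, spatial_dim), and all other numbers are example assumptions.
def _example_spatial_encoder_layer():
    layer = TransformerSpatialEncoderLayer(d_model=256, nhead=8, spatial_dim=5, spatial_attn_fusion='mul')
    tgt = torch.randn(2, 10, 256)                   # (batch, num_objects, d_model)
    pairwise_locs = torch.randn(2, 10, 10, 5)       # pairwise spatial features
    pad = torch.zeros(2, 10, dtype=torch.bool)      # True marks padded objects
    out, attn = layer(tgt, pairwise_locs, tgt_key_padding_mask=pad)
    # out: (2, 10, 256); attn: (nhead, 2, 10, 10) fused content + spatial attention
    return out, attn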
def _init_weights_bert(module, std=0.02):
    """
    Huggingface transformer weight initialization,
    most commonly used for BERT initialization.
    """
    if isinstance(module, nn.Linear):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
def generate_fourier_features(pos, num_bands=10, max_freq=15, concat_pos=True, sine_only=False):
    # Input: (B, N, C); Output: (B, N, C')
    batch_size = pos.shape[0]
    device = pos.device

    min_freq = 1.0
    # Nyquist frequency at the target resolution:
    freq_bands = torch.linspace(start=min_freq, end=max_freq, steps=num_bands, device=device)

    # Get frequency bands for each spatial dimension.
    # Output is size [n, d * num_bands]
    per_pos_features = pos.unsqueeze(-1).repeat(1, 1, 1, num_bands) * freq_bands
    per_pos_features = torch.reshape(
        per_pos_features, [batch_size, -1, np.prod(per_pos_features.shape[2:])])

    if sine_only:
        # Output is size [n, d * num_bands]
        per_pos_features = torch.sin(np.pi * (per_pos_features))
    else:
        # Output is size [n, 2 * d * num_bands]
        per_pos_features = torch.cat(
            [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1
        )
    # Concatenate the raw input positions.
    if concat_pos:
        # Adds d bands to the encoding.
        per_pos_features = torch.cat(
            [pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1)
    return per_pos_features
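
# Illustrative sketch (not part of the original module): with the defaults
# (num_bands=10, concat_pos=True, sine_only=False), an input of shape (B, N, C)
# maps to (B, N, C + 2 * C * num_bands). For a 4-dim orientation (e.g. a
# quaternion) that is 4 + 2 * 4 * 10 = 84, matching OSE3D's default fourier_size.
def _example_fourier_features():
    orient = torch.randn(2, 1, 4)   # assumed (B, N=1, C=4) orientation input
    feats = generate_fourier_features(orient)
    assert feats.shape == (2, 1, 84)
    return feats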
class OSE3D(nn.Module):
    # Open-vocabulary, Spatial-attention, Embodied-token, 3D-agent
    def __init__(self, use_spatial_attn=True, use_embodied_token=False, hidden_dim=256, fourier_size=84,
                 spatial_encoder={
                     "num_attention_heads": 8,
                     "dim_feedforward": 2048,
                     "dropout": 0.1,
                     "activation": "gelu",
                     "spatial_dim": 5,
                     "spatial_multihead": True,
                     "spatial_attn_fusion": "cond",
                     "num_layers": 3,
                     "pairwise_rel_type": "center",
                     "spatial_dist_norm": True,
                     "obj_loc_encoding": "same_all",
                     "dim_loc": 6,
                 }):
        super().__init__()
        self.use_spatial_attn = use_spatial_attn      # spatial attention
        self.use_embodied_token = use_embodied_token  # embodied token

        # pcd backbone
        # self.obj_encoder = PointcloudBackbone(backbone)
        self.obj_proj = nn.Linear(768, hidden_dim)

        # embodied token
        if self.use_embodied_token:
            self.anchor_feat = nn.Parameter(torch.zeros(1, 1, hidden_dim))
            self.anchor_size = nn.Parameter(torch.ones(1, 1, 3))
            self.orient_encoder = nn.Linear(fourier_size, hidden_dim)
        self.obj_type_embed = nn.Embedding(2, hidden_dim)

        # spatial encoder
        if self.use_spatial_attn:
            spatial_encoder_layer = TransformerSpatialEncoderLayer(
                d_model=hidden_dim,
                nhead=spatial_encoder['num_attention_heads'],
                dim_feedforward=spatial_encoder['dim_feedforward'],
                dropout=spatial_encoder['dropout'],
                activation=spatial_encoder['activation'],
                spatial_dim=spatial_encoder['spatial_dim'],
                spatial_multihead=spatial_encoder['spatial_multihead'],
                spatial_attn_fusion=spatial_encoder['spatial_attn_fusion'],
            )
        else:
            spatial_encoder_layer = TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=spatial_encoder['num_attention_heads'],
                dim_feedforward=spatial_encoder['dim_feedforward'],
                dropout=spatial_encoder['dropout'],
                activation=spatial_encoder['activation'],
            )

        self.spatial_encoder = layer_repeat(
            spatial_encoder_layer,
            spatial_encoder['num_layers'],
        )
        self.pairwise_rel_type = spatial_encoder['pairwise_rel_type']
        self.spatial_dist_norm = spatial_encoder['spatial_dist_norm']
        self.spatial_dim = spatial_encoder['spatial_dim']
        self.obj_loc_encoding = spatial_encoder['obj_loc_encoding']

        # location encoding
        if self.obj_loc_encoding in ['same_0', 'same_all']:
            num_loc_layers = 1
        elif self.obj_loc_encoding == 'diff_all':
            num_loc_layers = spatial_encoder['num_layers']
        else:
            raise NotImplementedError('unsupported obj_loc_encoding %s' % (self.obj_loc_encoding))

        loc_layer = nn.Sequential(
            nn.Linear(spatial_encoder['dim_loc'], hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        self.loc_layers = layer_repeat(loc_layer, num_loc_layers)

        # only initialize the spatial encoder and location layers
        self.spatial_encoder.apply(_init_weights_bert)
        self.loc_layers.apply(_init_weights_bert)

        if self.use_embodied_token:
            nn.init.normal_(self.anchor_feat, std=0.02)

    def device(self):
        return next(self.parameters()).device
    def forward(self, data_dict):
        """
        data_dict requires keys:
            obj_feats: (B, N, 768), precomputed object features; the point-cloud
                backbone that would consume obj_fts (B, N, P, 6), xyz + rgb, is disabled here
            obj_masks: (B, N), 1 = valid and 0 = masked
            obj_locs: (B, N, 6), xyz + whd
            anchor_locs: (B, 3)
            anchor_orientation: (B, C)
        """
        # obj_feats = self.obj_encoder(data_dict['obj_fts'])
        obj_feats = data_dict['obj_feats']
        obj_feats = self.obj_proj(obj_feats)
        obj_masks = ~data_dict['obj_masks']  # flipped due to the different convention of TransformerEncoder

        B, N = obj_feats.shape[:2]
        device = obj_feats.device

        obj_type_ids = torch.zeros((B, N), dtype=torch.long, device=device)
        obj_type_embeds = self.obj_type_embed(obj_type_ids)

        if self.use_embodied_token:
            # anchor feature
            anchor_orient = data_dict['anchor_orientation'].unsqueeze(1)
            anchor_orient_feat = self.orient_encoder(generate_fourier_features(anchor_orient))
            anchor_feat = self.anchor_feat + anchor_orient_feat
            anchor_mask = torch.zeros((B, 1), dtype=torch.bool, device=device)

            # anchor loc (3) + size (3)
            anchor_loc = torch.cat(
                [data_dict['anchor_locs'].unsqueeze(1), self.anchor_size.expand(B, -1, -1).to(device)], dim=-1
            )

            # anchor type
            anchor_type_id = torch.ones((B, 1), dtype=torch.long, device=device)
            anchor_type_embed = self.obj_type_embed(anchor_type_id)

            # fuse anchor and objs
            all_obj_feats = torch.cat([anchor_feat, obj_feats], dim=1)
            all_obj_masks = torch.cat((anchor_mask, obj_masks), dim=1)

            all_obj_locs = torch.cat([anchor_loc, data_dict['obj_locs']], dim=1)
            all_obj_type_embeds = torch.cat((anchor_type_embed, obj_type_embeds), dim=1)
        else:
            all_obj_feats = obj_feats
            all_obj_masks = obj_masks

            all_obj_locs = data_dict['obj_locs']
            all_obj_type_embeds = obj_type_embeds

        all_obj_feats = all_obj_feats + all_obj_type_embeds

        # call spatial encoder
        if self.use_spatial_attn:
            pairwise_locs = calc_pairwise_locs(
                all_obj_locs[:, :, :3],
                all_obj_locs[:, :, 3:],
                pairwise_rel_type=self.pairwise_rel_type,
                spatial_dist_norm=self.spatial_dist_norm,
                spatial_dim=self.spatial_dim,
            )

        for i, pc_layer in enumerate(self.spatial_encoder):
            if self.obj_loc_encoding == 'diff_all':
                query_pos = self.loc_layers[i](all_obj_locs)
            else:
                query_pos = self.loc_layers[0](all_obj_locs)
            # for 'same_0', add the location embedding only before the first layer
            if not (self.obj_loc_encoding == 'same_0' and i > 0):
                all_obj_feats = all_obj_feats + query_pos

            if self.use_spatial_attn:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats, pairwise_locs,
                    tgt_key_padding_mask=all_obj_masks
                )
            else:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats,
                    tgt_key_padding_mask=all_obj_masks
                )

        data_dict['obj_tokens'] = all_obj_feats
        data_dict['obj_masks'] = ~all_obj_masks

        # ###feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_img_gt.pth')
        # data_dict['obj_tokens'] = torch.load('assets/inputs/scene0350_00/obj_tokens.pth')
        # data_dict['obj_masks'] = torch.load('assets/inputs/scene0350_00/obj_masks.pth')
        return data_dict
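
# Illustrative sketch (not part of the original module): building the default
# OSE3D model and running it on random inputs. The batch size, object count,
# and the 768-dim precomputed features are assumptions for the example;
# 768 matches the expected input width of obj_proj.
if __name__ == '__main__':
    model = OSE3D()
    B, N = 2, 16
    data_dict = {
        'obj_feats': torch.randn(B, N, 768),              # precomputed object features
        'obj_masks': torch.ones(B, N, dtype=torch.bool),  # True = valid object
        'obj_locs': torch.rand(B, N, 6),                  # xyz center + whd size
    }
    data_dict = model(data_dict)
    print(data_dict['obj_tokens'].shape)                  # expected: (B, N, 256)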