import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from StructDiffusion.models.encoders import EncoderMLP, DropoutSampler, SinusoidalPositionEmbeddings
from StructDiffusion.models.point_transformer import PointTransformerEncoderSmall
from StructDiffusion.models.point_transformer_large import PointTransformerCls


class TransformerDiffusionModel(torch.nn.Module):
    """Transformer that predicts a 9-D pose output for every pose token.

    Each sequence is the concatenation (along the token axis) of language
    tokens followed by pose tokens.  Every token feature is the concatenation
    of a "word-sized" embedding (language embedding, or pose embedding +
    posed point-cloud embedding), a diffusion-time embedding, a position
    embedding and a token-type embedding.
    """

    def __init__(self, vocab_size,
                 # transformer params
                 encoder_input_dim=256,
                 num_attention_heads=8,
                 encoder_hidden_dim=16,
                 encoder_dropout=0.1,
                 encoder_activation="relu",
                 encoder_num_layers=8,
                 # output head params
                 structure_dropout=0.0,
                 object_dropout=0.0,
                 # pc encoder params
                 ignore_rgb=False,
                 pc_emb_dim=256,
                 posed_pc_emb_dim=80,
                 pose_emb_dim=80,
                 max_seq_size=7,
                 max_token_type_size=4,
                 seq_pos_emb_dim=8,
                 seq_type_emb_dim=8,
                 word_emb_dim=160,
                 time_emb_dim=80,
                 use_virtual_structure_frame=True,
                 use_sentence_embedding=False,
                 sentence_embedding_dim=None,
                 ):
        """Build the encoders, the transformer and the two output heads.

        Args:
            vocab_size: number of rows of the word-embedding table (only used
                when ``use_sentence_embedding`` is False).
            encoder_input_dim: transformer model width; must equal
                ``word_emb_dim + time_emb_dim + seq_pos_emb_dim + seq_type_emb_dim``.
            num_attention_heads: heads per transformer layer.
            encoder_hidden_dim: feed-forward width of each transformer layer.
            encoder_dropout: dropout used inside the transformer layers.
            encoder_activation: activation used inside the transformer layers.
            encoder_num_layers: number of stacked transformer layers.
            structure_dropout: dropout rate of the structure head.
            object_dropout: dropout rate of the object head.
            ignore_rgb: if True the point-cloud encoder only receives xyz
                (3 channels) instead of 6 channels.
            pc_emb_dim: output width of the point-cloud encoder.
            posed_pc_emb_dim: width of the posed point-cloud embedding.
            pose_emb_dim: width of the pose embedding;
                ``posed_pc_emb_dim + pose_emb_dim`` must equal ``word_emb_dim``.
            max_seq_size: size of the position-embedding table.
            max_token_type_size: size of the token-type-embedding table.
            seq_pos_emb_dim: width of the position embedding.
            seq_type_emb_dim: width of the token-type embedding.
            word_emb_dim: width of a language token embedding.
            time_emb_dim: width of the diffusion-time embedding.
            use_virtual_structure_frame: if True a learned embedding is
                prepended to the object point-cloud embeddings and the first
                predicted pose comes from the dedicated structure head.
            use_sentence_embedding: if True ``sentence`` is a single
                pre-computed vector per sample that is linearly projected to
                ``word_emb_dim``; otherwise it is a sequence of token ids.
            sentence_embedding_dim: width of the pre-computed sentence vector;
                required when ``use_sentence_embedding`` is True.

        Raises:
            AssertionError: if the embedding widths do not add up.
            ValueError: if ``use_sentence_embedding`` is True but
                ``sentence_embedding_dim`` is None.
        """
        super().__init__()

        assert posed_pc_emb_dim + pose_emb_dim == word_emb_dim
        assert encoder_input_dim == word_emb_dim + time_emb_dim + seq_pos_emb_dim + seq_type_emb_dim
        # Fail early with a readable message instead of an opaque error raised
        # from inside nn.Linear further down.
        if use_sentence_embedding and sentence_embedding_dim is None:
            raise ValueError("sentence_embedding_dim must be set when use_sentence_embedding is True")

        # 3D translation + 6D rotation
        action_dim = 3 + 6

        # default:
        # 256 = 80 (point cloud) + 80 (position) + 80 (time) + 8 (position idx) + 8 (token idx)
        # 256 = 160 (word embedding) + 80 (time) + 8 (position idx) + 8 (token idx)

        # NOTE: submodule attribute names and construction order are kept
        # stable on purpose so that state_dict keys and seeded initialisation
        # are unchanged.

        # PC
        self.ignore_rgb = ignore_rgb
        pc_input_dim = 3 if ignore_rgb else 6
        self.pc_encoder = PointTransformerEncoderSmall(output_dim=pc_emb_dim, input_dim=pc_input_dim,
                                                       mean_center=True)
        self.posed_pc_encoder = EncoderMLP(pc_emb_dim, posed_pc_emb_dim, uses_pt=True)

        # for virtual structure frame
        self.use_virtual_structure_frame = use_virtual_structure_frame
        if use_virtual_structure_frame:
            # stored as (1, 1, posed_pc_emb_dim); broadcast over the batch in forward()
            self.virtual_frame_embed = nn.Parameter(torch.randn(1, 1, posed_pc_emb_dim))

        # for language
        self.sentence_embedding_dim = sentence_embedding_dim
        self.use_sentence_embedding = use_sentence_embedding
        if use_sentence_embedding:
            self.sentence_embedding_down_sample = torch.nn.Linear(sentence_embedding_dim, word_emb_dim)
        else:
            self.word_embeddings = torch.nn.Embedding(vocab_size, word_emb_dim, padding_idx=0)

        # for diffusion
        self.pose_encoder = nn.Sequential(nn.Linear(action_dim, pose_emb_dim))
        self.time_embeddings = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # for transformer
        self.position_embeddings = torch.nn.Embedding(max_seq_size, seq_pos_emb_dim)
        self.type_embeddings = torch.nn.Embedding(max_token_type_size, seq_type_emb_dim)

        # batch_first is left at its default, so the encoder consumes
        # [sequence_length, batch_size, dim]; forward() transposes accordingly.
        encoder_layers = TransformerEncoderLayer(d_model=encoder_input_dim,
                                                 nhead=num_attention_heads,
                                                 dim_feedforward=encoder_hidden_dim,
                                                 dropout=encoder_dropout,
                                                 activation=encoder_activation)
        self.encoder = TransformerEncoder(encoder_layers, encoder_num_layers)

        self.struct_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=structure_dropout)
        self.obj_head = DropoutSampler(encoder_input_dim, action_dim, dropout_rate=object_dropout)

    def encode_posed_pc(self, pcs, batch_size, num_objects):
        """Encode flattened object point clouds into per-object embeddings.

        Args:
            pcs: tensor indexed as ``pcs[:, :, :3]`` (xyz) and ``pcs[:, :, 3:]``
                (remaining channels); forward() passes it with shape
                ``(batch_size * num_objects, num_pts, C)``.
            batch_size: batch size used to un-flatten the result.
            num_objects: objects per sample used to un-flatten the result.

        Returns:
            Tensor reshaped to ``(batch_size, num_objects, -1)``.
        """
        xyz = pcs[:, :, :3]
        # When rgb is ignored the encoder receives no per-point features.
        point_features = None if self.ignore_rgb else pcs[:, :, 3:]
        center_xyz, x = self.pc_encoder(xyz, point_features)
        posed_pc_embed = self.posed_pc_encoder(x, center_xyz)
        return posed_pc_embed.reshape(batch_size, num_objects, -1)

    def forward(self, t, pcs, sentence, poses, type_index, position_index, pad_mask):
        """Predict one 9-D output per input pose token.

        Args:
            t: diffusion time input fed to the time embedding
                (presumably shape ``(B,)`` - confirm against caller).
            pcs: object point clouds, ``(B, num_objects, num_pts, C)``.
            sentence: ``(B, sentence_embedding_dim)`` float vector when
                ``use_sentence_embedding`` is True, otherwise a 2-D tensor of
                token ids ``(B, sentence_len)``.
            poses: ``(B, num_poses, 9)`` poses; with the virtual structure
                frame the first pose token is the structure frame.
            type_index: ``(B, total_len)`` token-type ids.
            position_index: position ids, concatenated feature-wise with the
                sequence (so expected to be ``(B, total_len)``).
            pad_mask: padding mask; entries equal to 1 are treated as padding
                and ignored by attention (expected ``(B, total_len)``).

        Returns:
            Tensor with the same shape as ``poses``.

        Raises:
            ValueError: if token-id ``sentence`` is not 2-D.
            AssertionError: on sentence-embedding or output shape mismatch.
        """
        batch_size, num_objects, num_pts, _ = pcs.shape
        _, num_poses, _ = poses.shape
        if self.use_sentence_embedding:
            assert sentence.shape == (batch_size, self.sentence_embedding_dim), sentence.shape
        elif sentence.dim() != 2:
            raise ValueError(
                "sentence must be 2-D (batch, sentence_len), got shape {}".format(tuple(sentence.shape)))
        _, total_len = type_index.shape

        # Encode every object point cloud independently, then un-flatten.
        pcs = pcs.reshape(batch_size * num_objects, num_pts, -1)
        posed_pc_embed = self.encode_posed_pc(pcs, batch_size, num_objects)

        pose_embed = self.pose_encoder(poses)

        if self.use_virtual_structure_frame:
            # expand() gives a broadcast view; torch.cat materialises it.
            virtual_frame_embed = self.virtual_frame_embed.expand(batch_size, -1, -1)
            posed_pc_embed = torch.cat([virtual_frame_embed, posed_pc_embed], dim=1)
        tgt_obj_embed = torch.cat([pose_embed, posed_pc_embed], dim=-1)

        #########################
        if self.use_sentence_embedding:
            # sentence: B, sentence_embedding_dim -> B, 1, word_emb_dim
            sentence_embed = self.sentence_embedding_down_sample(sentence).unsqueeze(1)
        else:
            sentence_embed = self.word_embeddings(sentence)

        #########################
        # transformer time dim: sentence, struct, obj
        # transformer feat dim: obj pc + pose / word, time, token type, position
        time_embed = self.time_embeddings(t)  # B, dim
        time_embed = time_embed.unsqueeze(1).expand(-1, total_len, -1)  # B, L, dim
        position_embed = self.position_embeddings(position_index)
        type_embed = self.type_embeddings(type_index)

        tgt_sequence_encode = torch.cat([sentence_embed, tgt_obj_embed], dim=1)
        tgt_sequence_encode = torch.cat([tgt_sequence_encode, time_embed, position_embed, type_embed], dim=-1)

        #########################
        # sequence_encode: [batch size, sequence_length, encoder input dimension]
        # input to transformer needs to have dimension
        # [sequence_length, batch size, encoder input dimension]
        tgt_sequence_encode = tgt_sequence_encode.transpose(1, 0)
        # convert to bool; True marks positions attention must ignore
        tgt_pad_mask = (pad_mask == 1)
        # encode: [sequence_length, batch_size, embedding_size]
        encode = self.encoder(tgt_sequence_encode, src_key_padding_mask=tgt_pad_mask)
        encode = encode.transpose(1, 0)

        #########################
        # The pose tokens are the last num_poses tokens of the sequence.
        target_encodes = encode[:, -num_poses:, :]
        if self.use_virtual_structure_frame:
            obj_encodes = target_encodes[:, 1:, :]
            pred_obj_poses = self.obj_head(obj_encodes)  # B, N, 3 + 6
            # NOTE(review): this reads the FIRST token of the whole sequence
            # (a language token), not target_encodes[:, 0] (the structure-frame
            # token). Confirm this is intended; it is left unchanged because
            # altering it would change the behaviour of trained checkpoints.
            struct_encode = encode[:, 0, :].unsqueeze(1)
            # use a different sampler for struct prediction since it should have
            # larger variance than object predictions
            pred_struct_pose = self.struct_head(struct_encode)  # B, 1, 3 + 6
            pred_poses = torch.cat([pred_struct_pose, pred_obj_poses], dim=1)
        else:
            pred_poses = self.obj_head(target_encodes)  # B, N, 3 + 6

        assert pred_poses.shape == poses.shape

        return pred_poses


class FocalLoss(nn.Module):
    """Focal loss on logits: ``mean((1 - pt) ** gamma * BCE)``."""

    def __init__(self, gamma=2, alpha=.25):
        """Store the focusing parameter.

        Args:
            gamma: exponent applied to ``(1 - pt)``.
            alpha: accepted for interface compatibility only; class-balancing
                with alpha is NOT applied by this implementation.
        """
        super().__init__()
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: raw logits.
            targets: float targets with the same shape as ``inputs``.

        Returns:
            Scalar tensor.
        """
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) recovers the probability assigned to the true class.
        pt = torch.exp(-bce_loss)
        focal_loss = (1 - pt) ** self.gamma * bce_loss
        return focal_loss.mean()


class PCTDiscriminator(torch.nn.Module):
    """Point-transformer classifier producing one logit per scene."""

    def __init__(self, max_num_objects, include_env_pc=False, pct_random_sampling=False):
        """Build the classifier.

        Args:
            max_num_objects: number of one-hot object channels.
            include_env_pc: if True one extra input channel is added
                (presumably the one-hot slot for the environment point cloud).
            pct_random_sampling: forwarded as ``use_random_sampling``.
        """
        super().__init__()

        # input_dim: xyz + one hot for each object (+ 1 when the env pc is included)
        input_dim = max_num_objects + 3
        if include_env_pc:
            input_dim += 1
        self.classifier = PointTransformerCls(input_dim=input_dim, output_dim=1,
                                              use_random_sampling=pct_random_sampling)

    def forward(self, scene_xyz):
        """Return the classifier output (logits) for ``scene_xyz``."""
        return self.classifier(scene_xyz)

    def convert_logits(self, logits):
        """Map logits to probabilities with a sigmoid."""
        return torch.sigmoid(logits)