import torch
import torch.nn.functional as F
from torch import nn


class ShapeAttrEmbedding(nn.Module):
    """Embeds discrete shape attributes into a single fused feature vector.

    Each attribute label is one-hot encoded, mapped through its own small
    MLP, and the per-attribute embeddings are concatenated and fused by a
    final MLP into an embedding of size ``out_dim``.
    """

    def __init__(self, dim, out_dim, cls_num_list):
        super(ShapeAttrEmbedding, self).__init__()

        # One small MLP per attribute: one-hot (cls_num) -> dim -> dim.
        for idx, cls_num in enumerate(cls_num_list):
            setattr(
                self, f'attr_{idx}',
                nn.Sequential(
                    nn.Linear(cls_num, dim), nn.LeakyReLU(),
                    nn.Linear(dim, dim)))
        self.cls_num_list = cls_num_list
        self.attr_num = len(cls_num_list)
        # Fuse the concatenated per-attribute embeddings into out_dim.
        self.fusion = nn.Sequential(
            nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
            nn.Linear(out_dim, out_dim))

    def forward(self, attr):
        # attr: (batch, attr_num) tensor of integer class labels.
        attr_embedding_list = []
        for idx in range(self.attr_num):
            attr_embed_fc = getattr(self, f'attr_{idx}')
            attr_embedding_list.append(
                attr_embed_fc(
                    F.one_hot(
                        attr[:, idx],
                        num_classes=self.cls_num_list[idx]).to(torch.float32)))
        # (batch, attr_num * dim) -> (batch, out_dim)
        attr_embedding = torch.cat(attr_embedding_list, dim=1)
        attr_embedding = self.fusion(attr_embedding)

        return attr_embedding
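

# Minimal usage sketch (illustrative only, not part of the original module):
# the attribute count, class sizes, and embedding dimensions below are
# assumed for demonstration purposes.
if __name__ == '__main__':
    cls_num_list = [6, 4, 3]  # hypothetical number of classes per attribute
    model = ShapeAttrEmbedding(dim=32, out_dim=128, cls_num_list=cls_num_list)
    # Batch of 2 samples, one integer label per attribute.
    attr = torch.tensor([[0, 1, 2], [5, 3, 0]])
    embedding = model(attr)
    print(embedding.shape)  # expected: torch.Size([2, 128])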