import torch
import torch.nn as nn
from torch.nn import functional as F


class SemanticEmbedding(nn.Module):
    """Fuse a global visual feature with per-token text features into 31 sigmoid scores.

    forward() pads the text sequence to ``target_dim`` (60) tokens, takes the
    inner product of the visual vector with every token embedding, and maps the
    resulting 60 similarity scores through a linear layer + sigmoid.

    NOTE(review): several submodules built in ``__init__`` (``mesh_tf``,
    ``report_tf``, ``bn``, ``w1``, ``w2``, ``relu``, ``dropout``) are never
    used in ``forward``. They are kept so existing checkpoints still load —
    confirm they are dead weight before removing.
    """

    def __init__(self, args, mesh_dim=71, report_dim=761, embed_size=512):
        """Build all submodules.

        Args:
            args: unused here; accepted for signature compatibility with callers.
            mesh_dim: output width of ``mesh_tf`` (presumably a MeSH-tag vocabulary
                size — TODO confirm against caller).
            report_dim: output width of ``report_tf``.
            embed_size: width of the shared visual/text embedding space.
        """
        super().__init__()
        # Unused in forward(); retained for state-dict compatibility.
        self.mesh_tf = nn.Sequential(
            nn.Linear(embed_size, embed_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 2, embed_size // 4),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 4, mesh_dim),
        )
        # Unused in forward(); retained for state-dict compatibility.
        self.report_tf = nn.Sequential(
            nn.Linear(embed_size, embed_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 2, embed_size // 4),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 4, report_dim),
        )
        # bn / w1 / w2 / relu / dropout are likewise unused in forward().
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)
        self.w1 = nn.Linear(in_features=mesh_dim + report_dim,
                            out_features=embed_size)
        self.w2 = nn.Linear(in_features=embed_size, out_features=embed_size)
        self.relu = nn.ReLU()
        # Maps the 60 visual-text similarity scores to 31 output logits.
        self.logit = nn.Linear(60, 31)
        self.dropout = nn.Dropout(0.2)
        self.__init_weight()
        # Text sequences are padded (or truncated) to this length in forward().
        self.target_dim = 60
        self.sigm = nn.Sigmoid()

    def __init_weight(self):
        """Initialize w1/w2 weights uniformly in [-0.1, 0.1] with zero biases."""
        self.w1.weight.data.uniform_(-0.1, 0.1)
        self.w1.bias.data.fill_(0)
        self.w2.weight.data.uniform_(-0.1, 0.1)
        self.w2.bias.data.fill_(0)

    def forward(self, avg, pred_output):
        """Score the visual feature against each text token.

        Args:
            avg: (B, embed_size) pooled visual feature — assumed 2-D; TODO confirm.
            pred_output: (B, seq_len, embed_size) token embeddings with
                seq_len <= target_dim.

        Returns:
            (B, 31) tensor of sigmoid probabilities.
        """
        avg_visual = avg.unsqueeze(1)  # (B, 1, E)
        # Zero-pad the sequence axis (dim 1) up to target_dim.
        # NOTE(review): if pred_output.shape[1] > target_dim the pad amount is
        # negative and F.pad silently truncates the sequence — confirm intended.
        pred_output2 = F.pad(
            pred_output,
            (0, 0, 0, self.target_dim - pred_output.shape[1]),
            'constant', 0)
        pred = pred_output2.permute(0, 2, 1)  # (B, E, target_dim)
        # Batched inner product: one similarity score per (padded) token.
        visual_text = torch.matmul(avg_visual, pred).squeeze(1)  # (B, target_dim)
        outputs = self.sigm(self.logit(visual_text))
        return outputs


class classfication(nn.Module):
    """Sigmoid classification head: dropout -> linear -> sigmoid.

    (The misspelled name "classfication" is kept as-is — it is the public
    interface callers import.)
    """

    def __init__(self, distiller_num, avg_dim=1024):
        """Args:
            distiller_num: number of output classes.
            avg_dim: width of the input feature vector.
        """
        super().__init__()
        self.logit = nn.Linear(avg_dim, distiller_num)
        self.relu = nn.ReLU()  # unused in forward(); kept for compatibility
        self.sigm = nn.Sigmoid()
        self.dropout = nn.Dropout(0.5)

    def forward(self, avg):
        """Return (B, distiller_num) sigmoid probabilities for features ``avg``."""
        avg_visual = self.dropout(avg)
        x = self.logit(avg_visual)
        outputs = self.sigm(x)
        return outputs