SIAT-RZJS's picture
Upload 187 files
3b2b066 verified
import torch
import torch.nn as nn
from torch.nn import functional as F
class SemanticEmbedding(nn.Module):
    """Score a sequence of predictions against a pooled feature vector.

    ``forward`` takes the dot product between ``avg`` and every step of
    ``pred_output``, pads (or crops) the step axis to ``target_dim``, and
    maps the resulting score vector through ``logit`` followed by a sigmoid.

    ``mesh_tf``, ``report_tf``, ``bn``, ``w1``, ``w2``, ``relu`` and
    ``dropout`` are built in ``__init__`` but are not used by ``forward``.
    They are kept, in their original creation order, so that parameter names
    and seeded initialisation are unchanged.

    Args:
        args: Not read by this class.
        mesh_dim: Output width of ``mesh_tf``; also part of ``w1``'s input
            width.
        report_dim: Output width of ``report_tf``; also part of ``w1``'s
            input width.
        embed_size: Input width of ``mesh_tf`` / ``report_tf``, feature count
            of ``bn``, and output width of ``w1`` / ``w2``.
        target_dim: Length the step axis of ``pred_output`` is padded or
            cropped to; also the input width of ``logit``. Defaults to 60,
            the value that was previously hard-coded in two places.
        num_outputs: Output width of ``logit``. Defaults to 31, the value
            that was previously hard-coded.
    """

    def __init__(self, args, mesh_dim=71, report_dim=761, embed_size=512,
                 *, target_dim=60, num_outputs=31):
        super(SemanticEmbedding, self).__init__()
        # Single source of truth for the padded sequence length: ``logit``
        # below and the padding in ``forward`` must always agree on it.
        self.target_dim = target_dim

        self.mesh_tf = nn.Sequential(
            nn.Linear(embed_size, embed_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 2, embed_size // 4),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 4, mesh_dim)
        )
        self.report_tf = nn.Sequential(
            nn.Linear(embed_size, embed_size // 2),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 2, embed_size // 4),
            nn.ReLU(inplace=True),
            nn.Linear(embed_size // 4, report_dim)
        )
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)
        self.w1 = nn.Linear(in_features=mesh_dim + report_dim, out_features=embed_size)
        self.w2 = nn.Linear(in_features=embed_size, out_features=embed_size)
        self.relu = nn.ReLU()
        self.logit = nn.Linear(self.target_dim, num_outputs)
        self.dropout = nn.Dropout(0.2)
        self.__init_weight()
        self.sigm = nn.Sigmoid()

    def __init_weight(self):
        """Re-initialise ``w1`` and ``w2``: weights ~ U(-0.1, 0.1), biases 0."""
        self.w1.weight.data.uniform_(-0.1, 0.1)
        self.w1.bias.data.fill_(0)
        self.w2.weight.data.uniform_(-0.1, 0.1)
        self.w2.bias.data.fill_(0)

    def forward(self, avg, pred_output):
        """Return sigmoid scores computed from ``avg`` and ``pred_output``.

        Args:
            avg: Pooled feature tensor; expected to be ``(batch, feat)`` so
                that it becomes ``(batch, 1, feat)`` after ``unsqueeze(1)``.
            pred_output: 3-D tensor ``(batch, steps, feat)`` whose last
                dimension matches ``avg``'s feature dimension.

        Returns:
            Tensor of shape ``(batch, num_outputs)`` with values in (0, 1).
        """
        avg_visual = avg.unsqueeze(1)
        # Zero-pad the step axis (dim 1) at its end up to ``target_dim``.
        # When the input is longer than ``target_dim`` the amount is negative
        # and ``F.pad`` crops the trailing steps instead.
        missing_steps = self.target_dim - pred_output.shape[1]
        pred_padded = F.pad(pred_output, (0, 0, 0, missing_steps), 'constant', 0)
        # (batch, 1, feat) @ (batch, feat, target_dim) -> (batch, target_dim)
        visual_text = torch.matmul(avg_visual, pred_padded.permute(0, 2, 1)).squeeze(1)
        return self.sigm(self.logit(visual_text))
class classfication(nn.Module):
    """Sigmoid classification head: dropout -> linear -> sigmoid.

    Args:
        distiller_num: Output width of the linear layer.
        avg_dim: Input width of the linear layer (size of the last dimension
            of the tensor passed to ``forward``).
    """

    def __init__(self, distiller_num, avg_dim=1024):
        super().__init__()
        self.logit = nn.Linear(avg_dim, distiller_num)
        # ``relu`` is registered but not applied in ``forward``.
        self.relu = nn.ReLU()
        self.sigm = nn.Sigmoid()
        self.dropout = nn.Dropout(0.5)

    def forward(self, avg):
        """Return element-wise sigmoid scores for ``avg``.

        Dropout (p=0.5) is applied to the input before the linear layer; in
        eval mode it is a no-op.
        """
        return self.sigm(self.logit(self.dropout(avg)))