|
import torch |
|
from torch import nn |
|
|
|
|
|
class EmbeddingMLP(nn.Module):
    """Feed-forward encoder: Linear -> BatchNorm1d -> ReLU -> Linear.

    All layer widths scale with ``size``: the input has ``768 * size``
    features, the hidden layer ``900 * size`` and the output ``300 * size``.

    Args:
        size: Multiplier applied to every layer width.
    """

    def __init__(self, size=4):
        super().__init__()

        in_features = 768 * size
        hidden_features = 900 * size
        out_features = 300 * size

        # Layers are created in application order and kept under ``self.net``
        # so parameter names stay ``net.0.*`` ... ``net.3.*``.
        layers = [
            nn.Linear(in_features, hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, data):
        """Run ``data`` through the layer stack and return the result.

        Args:
            data: Tensor whose last dimension is ``768 * size``; presumably a
                2-D ``(batch, 768 * size)`` batch, given the BatchNorm1d
                layer -- confirm against callers.

        Returns:
            The output of the final linear layer (``300 * size`` features).
        """
        return self.net(data)
|
|
|
|
|
class PairClassifier(nn.Module):
    """Two-output head over a pair of items encoded by one shared encoder.

    Each input row holds two items side by side, ``768 * size`` features
    each. Both halves go through the same ``EmbeddingMLP`` instance (shared
    weights); the two encodings are concatenated and passed through a
    three-layer MLP ending in two output units.

    Args:
        size: Width multiplier forwarded to the encoder; it also fixes the
            column at which ``forward`` splits each row.
    """

    def __init__(self, size=4):
        super().__init__()

        # Width of one item. ``forward`` splits here; the original code
        # hard-coded ``768 * 4``, which only matched the encoder's input
        # width when ``size == 4``.
        self.item_width = 768 * size

        self.encoder = EmbeddingMLP(size)
        self.net = nn.Sequential(
            nn.Linear(300 * size * 2, 3000),
            nn.ReLU(),
            nn.Linear(3000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 2),
        )

    def forward(self, data):
        """Encode both halves of every row and score the pair.

        Args:
            data: Tensor indexed as ``data[:, cols]``; the first
                ``768 * size`` columns are the first item and the remaining
                columns the second item.

        Returns:
            Output of the final linear layer: two values per row (no
            softmax or other normalisation is applied here).
        """
        first = self.encoder(data[:, :self.item_width])
        second = self.encoder(data[:, self.item_width:])
        pair = torch.cat([first, second], dim=1)
        return self.net(pair)
|
|