import torch from torch import nn class EmbeddingMLP(nn.Module): def __init__(self, size=4): super().__init__() self.net = nn.Sequential( nn.Linear(768 * size, 900 * size), nn.BatchNorm1d(900 * size), nn.ReLU(), nn.Linear(900 * size, 300 * size) ) def forward(self, data): res = self.net(data) return res class PairClassifier(nn.Module): def __init__(self, size=4): super().__init__() self.encoder = EmbeddingMLP(size) self.net = nn.Sequential( nn.Linear(300 * size * 2, 3000), nn.ReLU(), nn.Linear(3000, 1000), nn.ReLU(), nn.Linear(1000, 2), ) def forward(self, data): e1 = self.encoder(data[:, :768 * 4]) e2 = self.encoder(data[:, 768 * 4:]) twins = torch.cat([e1, e2], dim=1) res = self.net(twins) return res