File size: 954 Bytes
43515a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
from torch import nn
class EmbeddingMLP(nn.Module):
    """Two-layer perceptron that projects a flat feature vector.

    The stack is ``Linear -> BatchNorm1d -> ReLU -> Linear`` and maps an input
    of width ``768 * size`` to an output of width ``300 * size`` through a
    hidden layer of width ``900 * size``.

    Args:
        size: Multiplier applied to the input, hidden and output widths.
    """

    def __init__(self, size=4):
        super().__init__()
        in_features = 768 * size
        hidden_features = 900 * size
        out_features = 300 * size

        # Layers are created in execution order so parameter initialisation
        # draws from the RNG in the same sequence as the layer stack.
        project_up = nn.Linear(in_features, hidden_features)
        normalise = nn.BatchNorm1d(hidden_features)
        activate = nn.ReLU()
        project_down = nn.Linear(hidden_features, out_features)

        self.net = nn.Sequential(project_up, normalise, activate, project_down)

    def forward(self, data):
        """Run ``data`` through the layer stack and return the projection.

        NOTE(review): BatchNorm1d in training mode needs more than one sample
        per batch; single-sample inference presumably relies on ``eval()``.
        """
        return self.net(data)
class PairClassifier(nn.Module):
    """Two-way classifier over a pair of embeddings sharing one encoder.

    The input holds two embeddings concatenated along dim 1, each
    ``768 * size`` columns wide. Both halves go through the same
    ``EmbeddingMLP`` instance; the two encodings are concatenated and fed to
    a three-layer MLP that emits two values per row.

    Args:
        size: Width multiplier; it fixes the width of each input half
            (``768 * size``) and of each encoding (``300 * size``).
    """

    def __init__(self, size=4):
        super().__init__()
        # Width of ONE embedding inside the concatenated input. forward()
        # splits on this instead of a hard-coded 768 * 4, so the model also
        # works when constructed with a non-default ``size``.
        self.input_dim = 768 * size
        self.encoder = EmbeddingMLP(size)
        self.net = nn.Sequential(
            nn.Linear(300 * size * 2, 3000),
            nn.ReLU(),
            nn.Linear(3000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 2),
        )

    def forward(self, data):
        """Encode both halves of ``data`` and classify the pair.

        Args:
            data: 2-D tensor whose dim 1 holds the first embedding in columns
                ``[0, 768 * size)`` and the second in the remaining columns.

        Returns:
            The output of the classification head: two values per input row.
        """
        # The same encoder (shared weights) processes both halves.
        e1 = self.encoder(data[:, :self.input_dim])
        e2 = self.encoder(data[:, self.input_dim:])
        twins = torch.cat([e1, e2], dim=1)
        return self.net(twins)
|