# RepoSnipy / common/pair_classifier.py
# Last commit c6a1f8c by HenryStephen: "topic cluster and code cluster"
import torch
from torch import nn
class EmbeddingMLP(nn.Module):
    """One-hidden-layer MLP that maps a flat feature vector to a smaller one.

    Layer stack (every width is scaled by ``size``)::

        Linear(768*size -> 900*size) -> BatchNorm1d(900*size) -> ReLU
            -> Linear(900*size -> 300*size)

    Args:
        size: width multiplier for all layers. The default of 4 gives
            3072 -> 3600 -> 1200.
    """

    def __init__(self, size=4):
        super().__init__()
        in_dim = 768 * size
        hidden_dim = 900 * size
        out_dim = 300 * size
        # The order of this list fixes the Sequential indices, and therefore
        # the state_dict keys (net.0 / net.1 / net.3). Do not reorder.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, data):
        """Run ``data`` through the MLP and return the projected features.

        ``data`` is presumably a 2-D batch of shape (batch, 768*size).
        TODO confirm against callers.

        In training mode, BatchNorm1d needs more than one sample per batch.
        """
        return self.net(data)
class PairClassifier(nn.Module):
    """Score a pair of inputs with one encoder shared by both halves.

    The input row is split into two equal halves. Each half goes through the
    same ``EmbeddingMLP`` (shared weights). The two encodings are concatenated
    and passed through an MLP head that outputs 2 values per row.

    Args:
        size: width multiplier forwarded to ``EmbeddingMLP``. Each half of the
            input must be 768*size wide, so a full row is 2*768*size.
    """

    def __init__(self, size=4):
        super().__init__()
        self.encoder = EmbeddingMLP(size)
        # The head sees both encodings side by side: 2 * (300 * size) features.
        self.net = nn.Sequential(
            nn.Linear(300 * size * 2, 3000),
            nn.ReLU(),
            nn.Linear(3000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 2),
        )

    def forward(self, data):
        """Encode both halves of ``data`` and return the head's raw 2-way output.

        Args:
            data: 2-D tensor whose columns are two equal-width halves, each
                768*size wide. The halves are presumably two embeddings
                concatenated by the caller; TODO confirm.

        Returns:
            Tensor of shape (batch, 2). No softmax is applied.
        """
        # The split point must match the ``size`` this model was built with.
        # Read it from the encoder's first Linear layer (768 * size) instead of
        # hard-coding 768 * 4, which broke every size other than 4.
        # No new attribute is needed, so older pickled models keep working.
        width = self.encoder.net[0].in_features
        e1 = self.encoder(data[:, :width])
        e2 = self.encoder(data[:, width:])
        twins = torch.cat([e1, e2], dim=1)
        return self.net(twins)