RuCommensClass / scr /model.py
ivanovot
update
d1474ea
raw
history blame contribute delete
968 Bytes
import torch
import torch.nn as nn
from .sbert import vectorize as vec
class Block(nn.Module):
    """Fully connected layer followed by dropout: ``Dropout(Linear(x))``.

    No activation is applied inside the block; callers add one between
    blocks if they need it.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the output tensor.
        dropout: Dropout probability applied after the linear layer.
            Defaults to 0.2, the value previously hard-coded here.
    """

    def __init__(self, input_dim, output_dim, dropout=0.2):
        super().__init__()
        # Keep the attribute name ``model`` and the Linear-then-Dropout order:
        # they determine the state_dict keys (``model.0.weight``,
        # ``model.0.bias``), so previously saved weights still load.
        self.model = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Project ``x`` to ``output_dim`` features, then apply dropout.

        Dropout is only active while the module is in training mode.
        """
        return self.model(x)
class Model(nn.Module):
    """Single-output scoring head: 1024 -> 512 -> 256 -> 128 -> 64 -> 1.

    Each stage is a ``Block`` (Linear + Dropout). A LeakyReLU sits between
    consecutive blocks, and a Sigmoid is applied to the final output, so
    each score lies in (0, 1). The input's last dimension must be 1024.
    Presumably this is the size of the embedding produced by
    ``sbert.vectorize``; confirm against that module.
    """

    def __init__(self):
        super().__init__()
        # NOTE: the attribute name ``model`` and this exact layer order define
        # the state_dict keys (e.g. ``model.0.model.0.weight``). Keep them
        # stable so previously trained weights still load.
        self.model = nn.Sequential(
            Block(1024, 512),
            nn.LeakyReLU(),
            Block(512, 256),
            nn.LeakyReLU(),
            Block(256, 128),
            nn.LeakyReLU(),
            Block(128, 64),
            nn.LeakyReLU(),
            Block(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return Sigmoid scores for ``x``, whose last dimension must be 1024."""
        return self.model(x)

    def predict(self, text):
        """Vectorize ``text`` and score it deterministically.

        ``text`` is passed unchanged to ``sbert.vectorize``. What that
        function accepts (a single string vs. a list of strings) and whether
        it returns a torch tensor is not visible here; TODO confirm.

        The forward pass runs with dropout disabled (eval mode) and without
        recording an autograd graph. The module's previous training/eval
        mode is restored afterwards, even if an exception is raised.

        Returns:
            The tensor produced by ``forward``, with values in (0, 1).
        """
        was_training = self.training
        # nn.Module starts in training mode. Without eval(), every Dropout
        # layer (including the one on the final logit) would make the
        # predictions random.
        self.eval()
        try:
            with torch.no_grad():
                return self(vec(text))
        finally:
            # Restore the caller's mode so predict() is safe mid-training.
            self.train(was_training)