# "Spaces: Sleeping" — Hugging Face Spaces status-banner residue from the page
# this file was extracted from; not part of the source code.
import torch
from torch import nn

from data import Tokenizer
class ResidualBlock(nn.Module):
    """Two-layer 1D convolutional block with an identity skip connection.

    Both convolutions preserve channel count and sequence length
    (kernel_size=3, padding=1), so the input can be added back to the
    block output without any projection.
    """

    def __init__(self, num_channels, dropout=0.5):
        super(ResidualBlock, self).__init__()
        # Submodule creation order is kept stable so seeded parameter
        # initialization stays reproducible.
        self.conv1 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(num_channels)
        self.conv2 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_channels)
        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: presumably (batch, num_channels, seq_len) — TODO confirm with caller.
        skip = x
        out = self.dropout(self.prelu(self.bn1(self.conv1(x))))
        out = self.dropout(self.prelu(self.bn2(self.conv2(out))))
        # NOTE(review): the skip is added after the final activation+dropout;
        # classic ResNet adds it before the last activation. Kept as-is to
        # preserve the trained model's behavior.
        return out + skip
class Seq2SeqCNN(nn.Module):
    """Convolutional sequence-to-sequence model.

    Maps a batch of source token ids to, for each source position,
    `many_to_one` independent distributions over the target vocabulary.
    """

    def __init__(self, config):
        """Build the model from a config dict.

        Expected keys:
            dict_size_src: source vocabulary size
            dict_size_trg: target vocabulary size
            embedding_dim: source embedding dimension
            num_channels: channel width of the convolutional trunk
            num_residual_blocks: number of ResidualBlock layers
            dropout: dropout probability
            many_to_one: target tokens predicted per source position
        """
        # nn.Module.__init__ must run before any submodules are registered.
        super(Seq2SeqCNN, self).__init__()
        self.config = config
        dict_size_src = config['dict_size_src']
        dict_size_trg = config['dict_size_trg']
        embedding_dim = config['embedding_dim']
        num_channels = config['num_channels']
        num_residual_blocks = config['num_residual_blocks']
        dropout = config['dropout']
        many_to_one = config['many_to_one']

        self.embedding = nn.Embedding(dict_size_src, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, num_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm1d(num_channels)
        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(num_channels, dropout) for _ in range(num_residual_blocks))
        )
        self.fc = nn.Linear(num_channels, dict_size_trg * many_to_one)
        self.dropout = nn.Dropout(dropout)
        self.dict_size_trg = dict_size_trg
        # Cached so forward() does not re-index the config dict every call.
        self.many_to_one = many_to_one

    def forward(self, src):
        """Forward pass.

        Args:
            src: (batch_size, seq_len) integer token ids.
        Returns:
            (batch_size, seq_len, many_to_one, dict_size_trg) logits.
        """
        batch_size = src.size(0)
        # Embed, then move channels first for Conv1d: (B, emb_dim, L).
        embedded = self.embedding(src).permute(0, 2, 1)
        conv_out = self.dropout(torch.relu(self.bn(self.conv(embedded))))  # (B, C, L)
        # Extra model-level skip around the whole residual stack.
        res_out = self.residual_blocks(conv_out) + conv_out
        # Back to (B, L, C) before the per-position projection.
        out = self.fc(self.dropout(res_out.permute(0, 2, 1)))
        return out.view(batch_size, -1, self.many_to_one, self.dict_size_trg)
def init_model(path, device="cpu"):
    """Rebuild a Seq2SeqCNN from a checkpoint file.

    The checkpoint is a dict holding 'config' (constructor config) and
    'state_dict' (trained weights).

    NOTE(review): torch.load unpickles arbitrary objects — only call this
    on trusted checkpoint files.
    """
    checkpoint = torch.load(path, map_location=device)
    model = Seq2SeqCNN(checkpoint['config']).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    return model
def _predict(model, src, device): | |
model.eval() | |
src = src.to(device) | |
output = model(src) | |
_, pred = torch.max(output, dim=-1) | |
# output = torch.softmax(output, dim=3) | |
# print(output.shape) | |
# pred = torch.multinomial(output.view(-1, output.size(-1)), 1) | |
# pred = pred.reshape(output.size()[:-1]) | |
# print(pred.shape) | |
return pred | |
def predict(model, tokenizer: "Tokenizer", text: str, device):
    """Run the model on raw text and decode the predicted target string.

    Positions whose source token is the tokenizer's pad index are dropped
    before decoding. Returns '' for empty input.
    """
    print('text:', text)
    if not text:
        return ''
    encoded = tokenizer.encode_src(text)
    src_batch = encoded.unsqueeze(0)
    preds = _predict(model, src_batch, device)[0]
    non_pad = src_batch[0] != tokenizer.src_pad_idx
    preds = preds[non_pad, :]
    predicted_text = ''.join(tokenizer.decode_trg(preds))
    print('predicted_text:', repr(predicted_text))
    # A '\u200c' (zero-width non-joiner) strip was considered here but is
    # deliberately not applied.
    return predicted_text