sobir-hf's picture
First commit
b56c828
import torch
from torch import nn
from data import Tokenizer
class ResidualBlock(nn.Module):
    """Residual unit built from two Conv1d + BatchNorm1d stages.

    Both convolutions map ``num_channels`` to ``num_channels`` with
    ``kernel_size=3`` and ``padding=1``. A single PReLU module and a single
    Dropout module are reused after each stage, and the block input is added
    to the result at the very end.

    Args:
        num_channels: Channel count of the input and of every internal layer.
        dropout: Probability handed to ``nn.Dropout``.
    """

    def __init__(self, num_channels, dropout=0.5):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}

        # First stage.
        self.conv1 = nn.Conv1d(num_channels, num_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm1d(num_channels)

        # Second stage.
        self.conv2 = nn.Conv1d(num_channels, num_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(num_channels)

        # Shared by both stages.
        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(prelu(bn2(conv2(h)))) + x`` where
        ``h = dropout(prelu(bn1(conv1(x))))``."""
        skip = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu(out)
        out = self.dropout(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.dropout(out)

        # The skip connection is applied in place, after the second
        # activation and dropout.
        out += skip
        return out
class Seq2SeqCNN(nn.Module):
    """Convolutional sequence model that scores target tokens per source position.

    The network is: embedding -> Conv1d -> BatchNorm1d -> ReLU -> dropout ->
    stack of ``ResidualBlock`` (plus an outer skip connection) -> dropout ->
    linear projection. The projection emits ``many_to_one`` groups of
    ``dict_size_trg`` scores for every source position.

    Args:
        config: Mapping with the keys
            ``dict_size_src`` (source vocabulary size),
            ``dict_size_trg`` (target vocabulary size),
            ``embedding_dim``, ``num_channels``, ``num_residual_blocks``,
            ``dropout`` (probability used by every dropout layer) and
            ``many_to_one`` (score groups emitted per source position).
            The mapping itself is kept on ``self.config``.

    Raises:
        KeyError: If any of the keys above is missing from ``config``.
    """

    def __init__(self, config):
        # Initialise nn.Module before assigning any attribute on ``self``;
        # Module.__setattr__ relies on state that __init__ creates.
        super().__init__()
        self.config = config

        dict_size_src = config['dict_size_src']
        dict_size_trg = config['dict_size_trg']
        embedding_dim = config['embedding_dim']
        num_channels = config['num_channels']
        num_residual_blocks = config['num_residual_blocks']
        dropout = config['dropout']
        many_to_one = config['many_to_one']

        self.embedding = nn.Embedding(dict_size_src, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, num_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm1d(num_channels)
        # One ResidualBlock per requested layer, applied in sequence.
        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(num_channels, dropout) for _ in range(num_residual_blocks))
        )
        self.fc = nn.Linear(num_channels, dict_size_trg * many_to_one)
        self.dropout = nn.Dropout(dropout)
        self.dict_size_trg = dict_size_trg

    def forward(self, src):
        """Score target tokens for every source position.

        Args:
            src: Index tensor of shape ``(batch_size, seq_len)``.

        Returns:
            Tensor of shape
            ``(batch_size, seq_len, many_to_one, dict_size_trg)`` holding the
            raw (un-normalised) outputs of the final linear layer.
        """
        batch_size = src.size(0)

        # Conv1d wants channels first: (batch, emb_dim, seq_len).
        embedded = self.embedding(src).permute(0, 2, 1)

        # (batch, num_channels, seq_len)
        conv_out = self.dropout(torch.relu(self.bn(self.conv(embedded))))

        # Outer skip connection around the whole residual stack.
        res_out = self.residual_blocks(conv_out) + conv_out

        # Back to (batch, seq_len, num_channels) for the per-position projection.
        out = self.fc(self.dropout(res_out.permute(0, 2, 1)))

        # Split the projection into ``many_to_one`` groups of target scores.
        return out.view(batch_size, -1, self.config['many_to_one'], self.dict_size_trg)
def init_model(path, device="cpu"):
    """Rebuild a ``Seq2SeqCNN`` from a checkpoint file.

    The checkpoint is read with ``torch.load`` and must contain a
    ``'state_dict'`` entry and a ``'config'`` entry; the latter is passed to
    the ``Seq2SeqCNN`` constructor.

    Args:
        path: Location of the checkpoint, forwarded to ``torch.load``.
        device: Device used both as ``map_location`` and as the model's device.

    Returns:
        The model on ``device`` with the stored weights loaded.
    """
    # NOTE: torch.load unpickles the file, so only open trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    weights = checkpoint['state_dict']

    model = Seq2SeqCNN(checkpoint['config'])
    model = model.to(device)
    model.load_state_dict(weights)
    return model
@torch.no_grad()
def _predict(model, src, device):
    """Return, for each position, the index of the highest model output.

    The model is put into evaluation mode (and left that way), ``src`` is
    moved to ``device``, and the maximum is taken along the last axis of
    whatever the model returns.

    Args:
        model: Module to evaluate.
        src: Input tensor handed to ``model`` after being moved to ``device``.
        device: Target device for ``src``.

    Returns:
        Tensor of indices with the model output's shape minus its last axis.
    """
    model.eval()
    logits = model(src.to(device))
    # Greedy choice: keep only the position of the maximum, not its value.
    return torch.max(logits, dim=-1).indices
@torch.no_grad()
def predict(model, tokenizer: "Tokenizer", text: str, device):
    """Run the model on a single string and return the decoded prediction.

    The input and the prediction are both echoed to stdout. Falsy ``text``
    short-circuits to an empty string without touching the model or tokenizer.

    Args:
        model: Model passed through to ``_predict``.
        tokenizer: Provides ``encode_src``, ``decode_trg`` and ``src_pad_idx``.
        text: Source string to convert.
        device: Device on which the model is evaluated.

    Returns:
        The pieces produced by ``tokenizer.decode_trg`` joined into one string.
    """
    print('text:', text)
    if not text:
        return ''

    # Encode and add a leading batch axis of size one.
    encoded = tokenizer.encode_src(text)
    batch = encoded.unsqueeze(0)

    # Only one sequence was submitted, so take the first prediction.
    first = _predict(model, batch, device)[0]

    # Discard predictions at positions where the source holds the pad index.
    keep = batch[0] != tokenizer.src_pad_idx
    kept = first[keep, :]

    predicted_text = ''.join(tokenizer.decode_trg(kept))
    print('predicted_text:', repr(predicted_text))
    return predicted_text