import torch
from torch import nn

from data import Tokenizer


class ResidualBlock(nn.Module):
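    """Length-preserving residual block: two conv(kernel_size=3, padding=1) →
    BatchNorm → PReLU → dropout stages, with the block input added back at the
    end. The single PReLU instance is reused, so both activations share one
    learnable slope.
    """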
    def __init__(self, num_channels, dropout=0.5):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(num_channels)
        self.conv2 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_channels)
        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        residual = x
        x = self.prelu(self.bn1(self.conv1(x)))
        x = self.dropout(x)
        x = self.bn2(self.conv2(x))
        x = self.prelu(x)
        x = self.dropout(x)
        x = x + residual  # out-of-place add; note: the classic ResNet ordering applies the activation after this sum, whereas here it comes before
        return x


class Seq2SeqCNN(nn.Module):
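    """Non-autoregressive convolutional seq2seq model: embedding → input conv →
    a stack of ResidualBlocks with a global skip connection → a per-position
    linear head that emits `many_to_one` target-token distributions for each
    source position, shaped (batch_size, seq_len, many_to_one, dict_size_trg).
    """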
    def __init__(self, config):
        super(Seq2SeqCNN, self).__init__()
        dict_size_src = config['dict_size_src']
        dict_size_trg = config['dict_size_trg']
        embedding_dim = config['embedding_dim']
        num_channels = config['num_channels']
        num_residual_blocks = config['num_residual_blocks']
        dropout = config['dropout']
        many_to_one = config['many_to_one']

        self.config = config

        self.embedding = nn.Embedding(dict_size_src, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, num_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm1d(num_channels)

        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(num_channels, dropout) for _ in range(num_residual_blocks))
            # Add as many blocks as required
        )
        self.fc = nn.Linear(num_channels, dict_size_trg * many_to_one)  # one target-token distribution per slot at each position
        self.dropout = nn.Dropout(dropout)
        self.dict_size_trg = dict_size_trg
        
    def forward(self, src):
        # src: (batch_size, seq_len)
        batch_size = src.size(0)
        
        embedded = self.embedding(src).permute(0, 2, 1)  # (batch_size, embedding_dim, seq_len)
        conv_out0 = self.conv(embedded)  # (batch_size, num_channels, seq_len)
        conv_out = self.dropout(torch.relu(self.bn(conv_out0)))
        res_out = self.residual_blocks(conv_out)
        res_out = res_out + conv_out  # global skip connection around the residual stack
        out = self.fc(self.dropout(res_out.permute(0, 2, 1)))  # back to (batch_size, seq_len, features)
        out = out.view(batch_size, -1, self.config['many_to_one'], self.dict_size_trg)
        return out


def init_model(path, device="cpu"):
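    """Rebuild a Seq2SeqCNN from a checkpoint dict with 'state_dict' and 'config' keys."""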
    d = torch.load(path, map_location=device)
    state_dict = d['state_dict']
    model = Seq2SeqCNN(d['config']).to(device)
    model.load_state_dict(state_dict)
    return model
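
# A checkpoint that init_model can consume would be saved like this (a sketch,
# assuming `model` is a trained Seq2SeqCNN):
#   torch.save({'state_dict': model.state_dict(), 'config': model.config}, 'checkpoint.pt')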



@torch.no_grad()
def _predict(model, src, device):
    model.eval()
    src = src.to(device)
    output = model(src)  # (batch_size, seq_len, many_to_one, dict_size_trg)
    # Greedy decoding: argmax over the target vocabulary.
    # (Sampling via softmax + torch.multinomial would be a drop-in alternative.)
    _, pred = torch.max(output, dim=-1)
    return pred


@torch.no_grad()
def predict(model, tokenizer: Tokenizer, text: str, device):
    print('text:', text)
    if not text:
        return ''
    text_encoded = tokenizer.encode_src(text)

    batch = text_encoded.unsqueeze(0)  # add a batch dimension: (1, seq_len)

    prd = _predict(model, batch, device)[0]
    prd = prd[batch[0] != tokenizer.src_pad_idx, :]  # keep predictions only at non-padding source positions

    predicted_text = ''.join(tokenizer.decode_trg(prd))
    print('predicted_text:', repr(predicted_text))
    return predicted_text  # optionally strip zero-width non-joiners: .replace('\u200c', '')
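

# Minimal smoke test: builds a small model from a hypothetical config and runs
# greedy prediction on random token ids. The key names mirror what
# Seq2SeqCNN.__init__ reads, but every size below is a placeholder, not a
# value used in training.
if __name__ == '__main__':
    config = {
        'dict_size_src': 100,      # source vocabulary size (placeholder)
        'dict_size_trg': 120,      # target vocabulary size (placeholder)
        'embedding_dim': 32,
        'num_channels': 64,
        'num_residual_blocks': 2,
        'dropout': 0.5,
        'many_to_one': 2,          # target tokens predicted per source position
    }
    model = Seq2SeqCNN(config)
    src = torch.randint(0, config['dict_size_src'], (4, 16))  # (batch_size, seq_len)
    pred = _predict(model, src, device='cpu')
    print('pred shape:', tuple(pred.shape))  # expected: (4, 16, 2)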