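"""Stress predictor for ARPAbet phoneme sequences.

A pre-trained transformer assigns a stress level (0, 1 or 2) to every vowel of
a stress-less phoneme sequence, e.g. ['N', 'IH', 'K', 'IY', 'T', 'AH'] becomes
['N', 'IH2', 'K', 'IY1', 'T', 'AH0']. Consonants are copied through unchanged;
only vowel positions are decoded by the model.
"""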
import os

import torch

from G2P_lexicon.config_models import config_sp
from G2P_lexicon.sp_tokenizer import Tokenizer_sp
from G2P_lexicon.transformer import TransformerBlock

dirname = os.path.dirname(__file__)


class Stress_Pred:
    """Greedy stress predictor wrapping a trained transformer and its tokenizer."""

    MAX_LEN = 32  # fixed input length the model was trained with

    def __init__(self, model, tokenizer):
        self.SP = model
        self.tokenizer = tokenizer
        self.SP.eval()  # inference only

    def __call__(self, phonemes):
        with torch.no_grad():
            # Encode the phoneme list and right-pad it to the fixed length.
            enc_input_tokens = self.tokenizer.encode(phonemes)
            pad_id = torch.tensor(self.tokenizer.pad_idx)
            enc_num_padding_tokens = self.MAX_LEN - len(enc_input_tokens)
            assert enc_num_padding_tokens >= 0, \
                "input is longer than the model's 32-token context"
            encoder_input = torch.cat(
                [
                    torch.tensor(enc_input_tokens),
                    pad_id.repeat(enc_num_padding_tokens),
                ],
                dim=0)

            # Attention mask that hides the padding positions.
            encoder_mask = (encoder_input != pad_id).unsqueeze(0).unsqueeze(0).int()
            label = self.greedy_decode_stress(
                src=encoder_input,
                src_mask=encoder_mask,
                start_token=self.tokenizer.sos_idx,
            )
        return label

    def greedy_decode_stress(self, src, src_mask, start_token):
        # Number of real (non-pad) tokens, including the SOS/EOS added by encode().
        len_src = (src != self.tokenizer.pad_idx).int().sum().item()
        # Positions holding non-vowel tokens (consonants, SOS/EOS): these are
        # copied to the output unchanged; vowel positions are decoded instead.
        copy_indices = torch.tensor(
            [idx for idx, i in enumerate(src)
             if idx < len_src and i.item() not in list_tokens_without_stress])
        src = src.unsqueeze(0)
        src_mask = src_mask.unsqueeze(0)
        input_decoder = self.SP.encode(src, src_mask)
        # Note: start_token is unused because encode() already prepends SOS;
        # it is stripped again (together with EOS) by the final [1:-1] slice.
        label = torch.empty(1, 0).type_as(src.data)

        for idx in range(len_src):
            if idx in copy_indices:
                # Non-vowel: copy the source token verbatim.
                label = torch.cat([label, src[0, idx].view(1, 1)], dim=1)
            else:
                # Vowel: run the decoder over the tokens emitted so far and
                # greedily pick the most likely stressed variant.
                tgt_mask = torch.tril(
                    torch.ones((label.size(1), label.size(1)))).type_as(src.data).unsqueeze(0)
                out = self.SP.decode(input_decoder, src_mask, label, tgt_mask)
                prob = self.SP.fc_out(out[:, -1])
                next_word = prob.argmax(dim=1).item()
                label = torch.cat(
                    [label, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)

        # Drop SOS/EOS and map token ids back to phoneme strings.
        pred = self.tokenizer.decode(label[0].tolist())[1:-1]
        return pred


# Module-level setup: paths are resolved relative to this file; the tokenizer
# and model are loaded once at import time so `SP` is ready to use on import.
# (os.path.join with separate components replaces the original backslash
# literals, which break on non-Windows systems.)
dict_path = os.path.join(dirname, "my_tokenizer", "sp_dict.json")
model_path = os.path.join(dirname, "models", "model_sp.pt")

tokenizer_sp = Tokenizer_sp(dict_path=dict_path)

# Collect the token ids of the stress-less vowel symbols: for every stressed
# phoneme in the vocabulary (e.g. "AH0"), record the id of its base form ("AH").
set_tokens_without_stress = set()
for phoneme in tokenizer_sp.idx2token.values():
    if phoneme[-1].isdigit():
        set_tokens_without_stress.add(tokenizer_sp.token2idx[phoneme[:-1]])
list_tokens_without_stress = list(set_tokens_without_stress)

# Build the transformer and load the pre-trained weights for CPU inference.
sp_model = TransformerBlock(config=config_sp,
                            tokenizer=tokenizer_sp)
sp_model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu')))

SP = Stress_Pred(model=sp_model,
                 tokenizer=tokenizer_sp)

if __name__ == '__main__':
    print(SP(['N', 'IH', 'K', 'IY', 'T', 'AH']))  # ['N', 'IH2', 'K', 'IY1', 'T', 'AH0']
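
    # Additional usage sketch (assuming these ARPAbet symbols are present in
    # the tokenizer's vocabulary): predict stress for several words in a loop.
    for word in (['HH', 'AH', 'L', 'OW'], ['P', 'AY', 'TH', 'AA', 'N']):
        print(SP(word))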