File size: 769 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

import os

def decode_sequence(vocab, seq):
    """Turn a batch of token-id rows into space-joined sentences.

    Args:
        vocab: Lookup from integer token id to word. It is indexed with
            plain Python ints, so both a list and an int-keyed dict work.
        seq: 2-D tensor-like object exposing ``size()`` and ``tolist()``
            (e.g. a ``torch.LongTensor``); one row per sentence.

    Returns:
        A list with one string per row of ``seq``. Each row is read left to
        right and decoding stops at the first id equal to 0; that id and
        everything after it are dropped.

    Raises:
        ValueError: If ``seq.size()`` does not unpack into two dimensions.
    """
    # Unpacking enforces that the input is exactly 2-D.
    _num_rows, _num_steps = seq.size()

    sents = []
    # tolist() yields native Python ints. Indexing vocab with 0-dim tensors
    # breaks dict lookups because tensors hash by identity, not by value.
    for row in seq.tolist():
        words = []
        for ix in row:
            if ix == 0:
                # 0 terminates the sentence.
                break
            words.append(vocab[ix])
        sents.append(' '.join(words))
    return sents

def decode_sequence_bert(tokenizer, seq, sep_token_id):
    """Decode a batch of token-id rows into strings using ``tokenizer``.

    Args:
        tokenizer: Object providing an ``ids_to_tokens`` mapping and a
            ``convert_tokens_to_string`` method (presumably a BERT-style
            tokenizer; confirm against the caller).
        seq: 2-D tensor of token ids; it is moved to the CPU and converted
            to a NumPy array before decoding.
        sep_token_id: Id that ends a row. Decoding of a row stops at its
            first occurrence; that id and everything after it are dropped.

    Returns:
        A list with one decoded string per row of ``seq``.
    """
    num_rows, num_cols = seq.size()
    ids = seq.data.cpu().numpy()

    decoded = []
    for row in range(num_rows):
        tokens = []
        col = 0
        # Collect tokens until the separator or the end of the row.
        while col < num_cols:
            token_id = ids[row, col]
            if token_id == sep_token_id:
                break
            tokens.append(tokenizer.ids_to_tokens[token_id])
            col += 1
        decoded.append(tokenizer.convert_tokens_to_string(tokens))
    return decoded