# File size: 769 Bytes
# 32b542e
import os
def decode_sequence(vocab, seq):
    """Turn each row of token ids in ``seq`` into a space-joined sentence.

    Every row is read left to right; decoding of a row stops at the first
    id equal to 0, and that id is not included in the output.

    Args:
        vocab: Mapping or sequence indexed by token id (``vocab[ix]``) that
            yields the string for that id.
        seq: Two-dimensional object exposing ``size()`` and ``tolist()``
            (presumably a ``torch.Tensor`` of integer ids -- confirm against
            callers).

    Returns:
        A list with one string per row of ``seq``.
    """
    # Unpacking into exactly two names keeps the original requirement that
    # seq is two-dimensional; the values themselves are not needed below.
    _num_rows, _num_steps = seq.size()

    sents = []
    # tolist() yields plain Python numbers. Indexing the tensor element by
    # element would yield 0-dim tensors, which hash by identity and therefore
    # miss every key of a dict-based vocabulary.
    for row in seq.tolist():
        words = []
        for ix in row:
            if ix == 0:
                break
            words.append(vocab[ix])
        sents.append(' '.join(words))
    return sents
def decode_sequence_bert(tokenizer, seq, sep_token_id):
    """Decode each row of token ids in ``seq`` into a string via ``tokenizer``.

    Each row is read left to right until the first id equal to
    ``sep_token_id``; that id and everything after it are ignored. The
    collected tokens are looked up in ``tokenizer.ids_to_tokens`` and merged
    with ``tokenizer.convert_tokens_to_string``.

    Args:
        tokenizer: Object providing an ``ids_to_tokens`` lookup and a
            ``convert_tokens_to_string`` method (presumably a BERT
            tokenizer -- confirm against callers).
        seq: Two-dimensional object exposing ``size()`` and
            ``.data.cpu().numpy()`` (presumably a ``torch.Tensor``).
        sep_token_id: Id that terminates decoding of a row.

    Returns:
        A list with one decoded string per row of ``seq``.
    """
    # Unpacking into exactly two names rejects input that is not 2-D.
    _num_rows, _num_steps = seq.size()
    id_matrix = seq.data.cpu().numpy()

    decoded = []
    for id_row in id_matrix:
        tokens = []
        for token_id in id_row:
            if token_id == sep_token_id:
                break
            tokens.append(tokenizer.ids_to_tokens[token_id])
        decoded.append(tokenizer.convert_tokens_to_string(tokens))
    return decoded