herrius's picture
Upload 259 files
32b542e
raw
history blame contribute delete
769 Bytes
import os
def decode_sequence(vocab, seq):
    """Decode a batch of word-index rows into space-joined sentences.

    Each row of ``seq`` is read left to right and stops at the first index
    0, which is not emitted. Every earlier index is looked up in ``vocab``,
    and the resulting words are joined with single spaces.

    Args:
        vocab: lookup from integer word index to word string, e.g. a list
            or a dict keyed by ``int``.
        seq: 2-D tensor of word indices with shape ``(N, T)``.

    Returns:
        A list of ``N`` strings, one per row of ``seq``.

    Raises:
        ValueError: if ``seq`` is not 2-D.
    """
    if seq.dim() != 2:
        raise ValueError(
            'expected a 2-D (N, T) index tensor, got shape {}'.format(
                tuple(seq.size())))
    # Copy to host once and convert to plain Python ints. This avoids
    # building a 0-dim tensor per element (and a device sync per comparison
    # on GPU). It also makes dict vocabularies with int keys work: tensors
    # hash by identity, so a tensor key never matches an int key.
    rows = seq.detach().cpu().tolist()
    sents = []
    for row in rows:
        words = []
        for ix in row:
            if ix == 0:  # index 0 terminates the row
                break
            words.append(vocab[ix])
        sents.append(' '.join(words))
    return sents
def decode_sequence_bert(tokenizer, seq, sep_token_id):
    """Decode a batch of token-id rows into strings with a tokenizer.

    ``seq`` is moved to the CPU as a NumPy array. Each row is scanned left
    to right and truncated at the first id equal to ``sep_token_id``; the
    separator itself is dropped. The remaining ids are mapped through
    ``tokenizer.ids_to_tokens``, and the resulting token list is passed to
    ``tokenizer.convert_tokens_to_string``.

    Args:
        tokenizer: object exposing ``ids_to_tokens`` (indexed by token id)
            and ``convert_tokens_to_string(list_of_tokens)`` -- presumably a
            BERT-style tokenizer; confirm against callers.
        seq: 2-D tensor of token ids with shape ``(N, T)``.
        sep_token_id: id that terminates each row.

    Returns:
        A list of ``N`` decoded strings, one per row.
    """
    n_rows, n_cols = seq.size()
    id_matrix = seq.data.cpu().numpy()

    def _tokens_until_sep(row):
        # Yield the token for each id, stopping at (and excluding) the separator.
        for col in range(n_cols):
            token_id = row[col]
            if token_id == sep_token_id:
                return
            yield tokenizer.ids_to_tokens[token_id]

    return [
        tokenizer.convert_tokens_to_string(list(_tokens_until_sep(id_matrix[r])))
        for r in range(n_rows)
    ]