|
|
|
import os |
|
|
|
def decode_sequence(vocab, seq):
    """Decode a batch of zero-terminated index sequences into sentences.

    Each row of ``seq`` is read left to right; decoding of a row stops at
    the first index equal to 0 (that index is not emitted). The looked-up
    words are joined with single spaces.

    Args:
        vocab: Index-to-word lookup addressed with plain ``int`` indices,
            e.g. a list or an ``int``-keyed dict.
        seq: 2-D batch of indices exposing ``size()`` and ``tolist()``
            (presumably a ``torch`` integer tensor of shape (N, T) --
            confirm against callers).

    Returns:
        A list with one decoded string per row of ``seq``; a row that
        starts with 0 yields the empty string.

    Raises:
        ValueError: If ``seq`` is not 2-dimensional.
    """
    shape = seq.size()
    if len(shape) != 2:
        raise ValueError(
            'seq must be 2-dimensional, got %d dimension(s)' % len(shape)
        )

    sents = []
    # Convert once to nested Python lists: tensor elements hash by identity,
    # so using them directly as dict keys would miss every int-keyed entry.
    for row in seq.tolist():
        words = []
        for ix in row:
            if ix == 0:
                break
            words.append(vocab[ix])
        sents.append(' '.join(words))
    return sents
|
|
|
def decode_sequence_bert(tokenizer, seq, sep_token_id):
    """Decode a batch of token-id sequences into strings via ``tokenizer``.

    Each row of ``seq`` is read left to right; decoding of a row stops at
    the first id equal to ``sep_token_id`` (that id is not emitted). The
    collected tokens are turned into text with
    ``tokenizer.convert_tokens_to_string``.

    Args:
        tokenizer: Object providing an ``ids_to_tokens`` mapping and a
            ``convert_tokens_to_string(tokens)`` method (presumably a BERT
            WordPiece tokenizer -- confirm against callers).
        seq: 2-D batch of token ids exposing ``size()``, ``detach()``,
            ``cpu()`` and ``tolist()`` (presumably a ``torch`` integer
            tensor of shape (N, T)).
        sep_token_id: Id that terminates a row.

    Returns:
        A list with one decoded string per row of ``seq``.

    Raises:
        ValueError: If ``seq`` is not 2-dimensional.
    """
    shape = seq.size()
    if len(shape) != 2:
        raise ValueError(
            'seq must be 2-dimensional, got %d dimension(s)' % len(shape)
        )

    id_to_token = tokenizer.ids_to_tokens
    sents = []
    # detach() replaces the legacy .data accessor; tolist() hands back plain
    # Python ints, so no NumPy round-trip is needed for the lookups.
    for row in seq.detach().cpu().tolist():
        tokens = []
        for ix in row:
            if ix == sep_token_id:
                break
            tokens.append(id_to_token[ix])
        sents.append(tokenizer.convert_tokens_to_string(tokens))
    return sents