|
|
|
|
|
import torch |
|
from itertools import accumulate |
|
from onmt.constants import SubwordMarker |
|
|
|
|
|
def make_batch_align_matrix(index_tensor, size=None, normalize=False): |
|
""" |
|
Convert a sparse index_tensor into a batch of alignment matrix, |
|
with row normalize to the sum of 1 if set normalize. |
|
|
|
Args: |
|
index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id] |
|
size (List[int]): Size of the sparse tensor. |
|
normalize (bool): if normalize the 2nd dim of resulting tensor. |
|
""" |
|
n_fill, device = index_tensor.size(0), index_tensor.device |
|
value_tensor = torch.ones([n_fill], dtype=torch.float) |
|
dense_tensor = torch.sparse_coo_tensor( |
|
index_tensor.t(), value_tensor, size=size, device=device).to_dense() |
|
if normalize: |
|
row_sum = dense_tensor.sum(-1, keepdim=True) |
|
|
|
torch.nn.functional.threshold(row_sum, 1, 1, inplace=True) |
|
dense_tensor.div_(row_sum) |
|
return dense_tensor |
|
|
|
|
|
def extract_alignment(align_matrix, tgt_mask, src_lens, n_best): |
|
""" |
|
Extract a batched align_matrix into its src indice alignment lists, |
|
with tgt_mask to filter out invalid tgt position as EOS/PAD. |
|
BOS already excluded from tgt_mask in order to match prediction. |
|
|
|
Args: |
|
align_matrix (Tensor): ``(B, tgt_len, src_len)``, |
|
attention head normalized by Softmax(dim=-1) |
|
tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD. |
|
src_lens (LongTensor): ``(B,)``, containing valid src length |
|
n_best (int): a value indicating number of parallel translation. |
|
* B: denote flattened batch as B = batch_size * n_best. |
|
|
|
Returns: |
|
alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best,)``, |
|
containing valid alignment matrix (or None if blank prediction) |
|
for each translation. |
|
""" |
|
batch_size_n_best = align_matrix.size(0) |
|
assert batch_size_n_best % n_best == 0 |
|
|
|
alignments = [[] for _ in range(batch_size_n_best // n_best)] |
|
|
|
|
|
for i, (am_b, tgt_mask_b, src_len) in enumerate( |
|
zip(align_matrix, tgt_mask, src_lens)): |
|
valid_tgt = ~tgt_mask_b |
|
valid_tgt_len = valid_tgt.sum() |
|
if valid_tgt_len == 0: |
|
|
|
valid_alignment = None |
|
else: |
|
|
|
am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)) \ |
|
.view(valid_tgt_len, -1) |
|
valid_alignment = am_valid_tgt[:, :src_len] |
|
alignments[i // n_best].append(valid_alignment) |
|
|
|
return alignments |
|
|
|
|
|
def build_align_pharaoh(valid_alignment): |
|
"""Convert valid alignment matrix to i-j (from 0) Pharaoh format pairs, |
|
or empty list if it's None. |
|
""" |
|
align_pairs = [] |
|
if isinstance(valid_alignment, torch.Tensor): |
|
tgt_align_src_id = valid_alignment.argmax(dim=-1) |
|
|
|
for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()): |
|
align_pairs.append(str(src_id) + "-" + str(tgt_id)) |
|
align_pairs.sort(key=lambda x: int(x.split('-')[-1])) |
|
align_pairs.sort(key=lambda x: int(x.split('-')[0])) |
|
return align_pairs |
|
|
|
|
|
def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'): |
|
"""Convert subword alignment to word alignment. |
|
|
|
Args: |
|
src (string): tokenized sentence in source language. |
|
tgt (string): tokenized sentence in target language. |
|
subword_align (string): align_pharaoh correspond to src-tgt. |
|
m_src (string): tokenization mode used in src, |
|
can be ["joiner", "spacer"]. |
|
m_tgt (string): tokenization mode used in tgt, |
|
can be ["joiner", "spacer"]. |
|
|
|
Returns: |
|
word_align (string): converted alignments correspand to |
|
detokenized src-tgt. |
|
""" |
|
assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!" |
|
assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!" |
|
|
|
src, tgt = src.strip().split(), tgt.strip().split() |
|
subword_align = {(int(a), int(b)) for a, b in (x.split("-") |
|
for x in subword_align.split())} |
|
|
|
src_map = (subword_map_by_spacer(src) if m_src == 'spacer' |
|
else subword_map_by_joiner(src)) |
|
|
|
tgt_map = (subword_map_by_spacer(src) if m_tgt == 'spacer' |
|
else subword_map_by_joiner(src)) |
|
|
|
word_align = list({"{}-{}".format(src_map[a], tgt_map[b]) |
|
for a, b in subword_align}) |
|
word_align.sort(key=lambda x: int(x.split('-')[-1])) |
|
word_align.sort(key=lambda x: int(x.split('-')[0])) |
|
return " ".join(word_align) |
|
|
|
|
|
|
|
def begin_uppercase(token): |
|
return token == SubwordMarker.BEGIN_UPPERCASE |
|
|
|
|
|
def end_uppercase(token): |
|
return token == SubwordMarker.END_UPPERCASE |
|
|
|
|
|
def begin_case(token): |
|
return token == SubwordMarker.BEGIN_CASED |
|
|
|
|
|
def case_markup(token): |
|
return begin_uppercase(token) \ |
|
or end_uppercase(token) \ |
|
or begin_case(token) |
|
|
|
|
|
def subword_map_by_joiner(subwords, |
|
original_subwords=None, |
|
marker=SubwordMarker.JOINER): |
|
"""Return word id for each subword token (annotate by joiner).""" |
|
|
|
flags = [1] * len(subwords) |
|
j = 0 |
|
finished = True |
|
for i, tok in enumerate(subwords): |
|
|
|
previous_tok = subwords[i-1] if i else "" |
|
previous_tok_2 = subwords[i-2] if i > 1 else "" |
|
|
|
|
|
current_original_subword = "" if not original_subwords \ |
|
else original_subwords[j] if j < len(original_subwords) else "" |
|
|
|
if tok.startswith(marker) and tok != current_original_subword: |
|
flags[i] = 0 |
|
elif (previous_tok.endswith(marker) |
|
or begin_case(previous_tok) |
|
or begin_uppercase(previous_tok)) \ |
|
and not finished: |
|
flags[i] = 0 |
|
elif previous_tok_2.endswith(marker) \ |
|
and case_markup(previous_tok) \ |
|
and not finished: |
|
flags[i] = 0 |
|
elif end_uppercase(tok) and tok != current_original_subword: |
|
flags[i] = 0 |
|
else: |
|
finished = False |
|
if tok == current_original_subword: |
|
finished = True |
|
j += 1 |
|
|
|
flags[0] = 0 |
|
word_group = list(accumulate(flags)) |
|
|
|
if original_subwords: |
|
assert max(word_group) < len(original_subwords) |
|
return word_group |
|
|
|
|
|
def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER): |
|
"""Return word id for each subword token (annotate by spacer).""" |
|
flags = [0] * len(subwords) |
|
for i, tok in enumerate(subwords): |
|
if marker in tok: |
|
if case_markup(tok.replace(marker, "")): |
|
if i < len(subwords)-1: |
|
flags[i] = 1 |
|
else: |
|
if i > 0: |
|
previous = subwords[i-1].replace(marker, "") |
|
if not case_markup(previous): |
|
flags[i] = 1 |
|
|
|
|
|
for i in range(1, len(subwords)-1): |
|
if case_markup(subwords[-i]): |
|
flags[-i] = 0 |
|
elif subwords[-i] == marker: |
|
flags[-i] = 0 |
|
break |
|
|
|
word_group = list(accumulate(flags)) |
|
if word_group[0] == 1: |
|
word_group = [item - 1 for item in word_group] |
|
return word_group |
|
|