# -*- coding: utf-8 -*-
import torch
from itertools import accumulate
from onmt.constants import SubwordMarker


def make_batch_align_matrix(index_tensor, size=None, normalize=False):
"""
Convert a sparse index_tensor into a batch of alignment matrix,
with row normalize to the sum of 1 if set normalize.
Args:
index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id]
size (List[int]): Size of the sparse tensor.
normalize (bool): if normalize the 2nd dim of resulting tensor.
"""
n_fill, device = index_tensor.size(0), index_tensor.device
value_tensor = torch.ones([n_fill], dtype=torch.float)
dense_tensor = torch.sparse_coo_tensor(
index_tensor.t(), value_tensor, size=size, device=device).to_dense()
if normalize:
row_sum = dense_tensor.sum(-1, keepdim=True) # sum by row(tgt)
# threshold on 1 to avoid div by 0
torch.nn.functional.threshold(row_sum, 1, 1, inplace=True)
dense_tensor.div_(row_sum)
return dense_tensor
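

# Illustrative usage sketch (added for clarity, not part of the original
# module); the alignment points and sizes below are hypothetical.
def _example_make_batch_align_matrix():
    # three alignment points for one sentence pair: (batch_id, tgt_id, src_id)
    index = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 2]])
    dense = make_batch_align_matrix(index, size=[1, 2, 3], normalize=True)
    # dense[0] -> [[0.5, 0.5, 0.0],
    #              [0.0, 0.0, 1.0]]  (each tgt row sums to 1)
    return dense
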
def extract_alignment(align_matrix, tgt_mask, src_lens, n_best):
"""
Extract a batched align_matrix into its src indice alignment lists,
with tgt_mask to filter out invalid tgt position as EOS/PAD.
BOS already excluded from tgt_mask in order to match prediction.
Args:
align_matrix (Tensor): ``(B, tgt_len, src_len)``,
attention head normalized by Softmax(dim=-1)
tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD.
src_lens (LongTensor): ``(B,)``, containing valid src length
n_best (int): a value indicating number of parallel translation.
* B: denote flattened batch as B = batch_size * n_best.
Returns:
alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best,)``,
containing valid alignment matrix (or None if blank prediction)
for each translation.
"""
batch_size_n_best = align_matrix.size(0)
assert batch_size_n_best % n_best == 0
alignments = [[] for _ in range(batch_size_n_best // n_best)]
    # treat alignment matrices one by one, as each has a different length
for i, (am_b, tgt_mask_b, src_len) in enumerate(
zip(align_matrix, tgt_mask, src_lens)):
valid_tgt = ~tgt_mask_b
valid_tgt_len = valid_tgt.sum()
if valid_tgt_len == 0:
            # No alignment if there is no valid tgt token
valid_alignment = None
else:
            # get valid alignment (sub-matrix of the padded alignment matrix)
am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)) \
.view(valid_tgt_len, -1)
valid_alignment = am_valid_tgt[:, :src_len] # only keep valid src
alignments[i // n_best].append(valid_alignment)
return alignments
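

# Illustrative usage sketch (added for clarity, not part of the original
# module); the shapes and mask values below are hypothetical.
def _example_extract_alignment():
    # flattened batch B = batch_size * n_best = 2, tgt_len = 3, src_len = 4
    align_matrix = torch.rand(2, 3, 4).softmax(dim=-1)
    tgt_mask = torch.tensor([[False, False, True],   # last tgt position is EOS
                             [False, True, True]])   # only one valid tgt token
    src_lens = torch.tensor([3, 2])
    alignments = extract_alignment(align_matrix, tgt_mask, src_lens, n_best=1)
    # alignments[0][0].shape -> (2, 3); alignments[1][0].shape -> (1, 2)
    return alignments
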
def build_align_pharaoh(valid_alignment):
"""Convert valid alignment matrix to i-j (from 0) Pharaoh format pairs,
or empty list if it's None.
"""
align_pairs = []
if isinstance(valid_alignment, torch.Tensor):
tgt_align_src_id = valid_alignment.argmax(dim=-1)
for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()):
align_pairs.append(str(src_id) + "-" + str(tgt_id))
align_pairs.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
align_pairs.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
return align_pairs
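

# Illustrative usage sketch (added for clarity, not part of the original
# module); the attention values below are hypothetical.
def _example_build_align_pharaoh():
    # 2 tgt tokens x 3 src tokens; each tgt row points to its argmax src token
    valid_alignment = torch.tensor([[0.7, 0.2, 0.1],
                                    [0.1, 0.1, 0.8]])
    return build_align_pharaoh(valid_alignment)  # -> ['0-0', '2-1']
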
def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'):
"""Convert subword alignment to word alignment.
Args:
src (string): tokenized sentence in source language.
tgt (string): tokenized sentence in target language.
subword_align (string): align_pharaoh correspond to src-tgt.
m_src (string): tokenization mode used in src,
can be ["joiner", "spacer"].
m_tgt (string): tokenization mode used in tgt,
can be ["joiner", "spacer"].
Returns:
word_align (string): converted alignments correspand to
detokenized src-tgt.
"""
assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!"
assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!"
src, tgt = src.strip().split(), tgt.strip().split()
subword_align = {(int(a), int(b)) for a, b in (x.split("-")
for x in subword_align.split())}
src_map = (subword_map_by_spacer(src) if m_src == 'spacer'
else subword_map_by_joiner(src))
    tgt_map = (subword_map_by_spacer(tgt) if m_tgt == 'spacer'
               else subword_map_by_joiner(tgt))
word_align = list({"{}-{}".format(src_map[a], tgt_map[b])
for a, b in subword_align})
word_align.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
word_align.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
return " ".join(word_align)
# Helper functions


def begin_uppercase(token):
    return token == SubwordMarker.BEGIN_UPPERCASE


def end_uppercase(token):
    return token == SubwordMarker.END_UPPERCASE


def begin_case(token):
    return token == SubwordMarker.BEGIN_CASED


def case_markup(token):
    return begin_uppercase(token) \
        or end_uppercase(token) \
        or begin_case(token)
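

# Illustrative check (added for clarity, not part of the original module):
# case markup placeholders are detected by case_markup; ordinary subwords
# (e.g. the hypothetical "hello") are not.
def _example_case_markup():
    return [case_markup(SubwordMarker.BEGIN_UPPERCASE),
            case_markup("hello")]  # -> [True, False]
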
def subword_map_by_joiner(subwords,
original_subwords=None,
marker=SubwordMarker.JOINER):
"""Return word id for each subword token (annotate by joiner)."""
flags = [1] * len(subwords)
j = 0
finished = True
for i, tok in enumerate(subwords):
previous_tok = subwords[i-1] if i else "" # Previous N-1 token
previous_tok_2 = subwords[i-2] if i > 1 else "" # Previous N-2 token
# Keeps track of the original words/subwords
# ('prior_tokenization' option)
current_original_subword = "" if not original_subwords \
else original_subwords[j] if j < len(original_subwords) else ""
if tok.startswith(marker) and tok != current_original_subword:
flags[i] = 0
elif (previous_tok.endswith(marker)
or begin_case(previous_tok)
or begin_uppercase(previous_tok)) \
and not finished:
flags[i] = 0
elif previous_tok_2.endswith(marker) \
and case_markup(previous_tok) \
and not finished:
flags[i] = 0
elif end_uppercase(tok) and tok != current_original_subword:
flags[i] = 0
else:
finished = False
if tok == current_original_subword:
finished = True
j += 1
flags[0] = 0
word_group = list(accumulate(flags))
if original_subwords:
assert max(word_group) < len(original_subwords)
return word_group
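

# Illustrative usage sketch (added for clarity, not part of the original
# module); the subword sequence below is hypothetical.
def _example_subword_map_by_joiner():
    joiner = SubwordMarker.JOINER
    # the joiner-prefixed "ly" attaches to "New"; "built" starts a new word
    subwords = ["New", joiner + "ly", "built"]
    return subword_map_by_joiner(subwords)  # -> [0, 0, 1]
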
def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
"""Return word id for each subword token (annotate by spacer)."""
flags = [0] * len(subwords)
for i, tok in enumerate(subwords):
if marker in tok:
if case_markup(tok.replace(marker, "")):
if i < len(subwords)-1:
flags[i] = 1
else:
if i > 0:
previous = subwords[i-1].replace(marker, "")
if not case_markup(previous):
flags[i] = 1
# In case there is a final case_markup when new_spacer is on
for i in range(1, len(subwords)-1):
if case_markup(subwords[-i]):
flags[-i] = 0
elif subwords[-i] == marker:
flags[-i] = 0
break
word_group = list(accumulate(flags))
if word_group[0] == 1: # when dummy prefix is set
word_group = [item - 1 for item in word_group]
return word_group
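

# Illustrative usage sketch (added for clarity, not part of the original
# module); the subword sequence below is hypothetical.
def _example_subword_map_by_spacer():
    spacer = SubwordMarker.SPACER
    # spacer-prefixed subwords start new words; "er" attaches to "York"
    subwords = [spacer + "New", spacer + "York", "er"]
    return subword_map_by_spacer(subwords)  # -> [0, 1, 1]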