# -*- coding: utf-8 -*-
import torch
from itertools import accumulate
from onmt.constants import SubwordMarker
def make_batch_align_matrix(index_tensor, size=None, normalize=False):
"""
Convert a sparse index_tensor into a batch of alignment matrix,
with row normalize to the sum of 1 if set normalize.
Args:
index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id]
size (List[int]): Size of the sparse tensor.
normalize (bool): if normalize the 2nd dim of resulting tensor.
"""
n_fill, device = index_tensor.size(0), index_tensor.device
value_tensor = torch.ones([n_fill], dtype=torch.float)
dense_tensor = torch.sparse_coo_tensor(
index_tensor.t(), value_tensor, size=size, device=device).to_dense()
if normalize:
row_sum = dense_tensor.sum(-1, keepdim=True) # sum by row(tgt)
# threshold on 1 to avoid div by 0
torch.nn.functional.threshold(row_sum, 1, 1, inplace=True)
dense_tensor.div_(row_sum)
return dense_tensor
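
# Illustrative sketch (added for documentation, not part of the original
# module): how a sparse [batch_id, tgt_id, src_id] index tensor becomes a
# dense, row-normalized alignment matrix. The values are made up.
def _example_make_batch_align_matrix():
    # tgt token 0 aligns to src token 0; tgt token 1 aligns to src 1 and 2.
    index = torch.tensor([[0, 0, 0], [0, 1, 1], [0, 1, 2]])
    matrix = make_batch_align_matrix(index, size=[1, 2, 3], normalize=True)
    # matrix[0] == [[1.0, 0.0, 0.0],
    #               [0.0, 0.5, 0.5]]
    return matrix
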
def extract_alignment(align_matrix, tgt_mask, src_lens, n_best):
"""
Extract a batched align_matrix into its src indice alignment lists,
with tgt_mask to filter out invalid tgt position as EOS/PAD.
BOS already excluded from tgt_mask in order to match prediction.
Args:
align_matrix (Tensor): ``(B, tgt_len, src_len)``,
attention head normalized by Softmax(dim=-1)
tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD.
src_lens (LongTensor): ``(B,)``, containing valid src length
n_best (int): a value indicating number of parallel translation.
* B: denote flattened batch as B = batch_size * n_best.
Returns:
alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best,)``,
containing valid alignment matrix (or None if blank prediction)
for each translation.
"""
batch_size_n_best = align_matrix.size(0)
assert batch_size_n_best % n_best == 0
alignments = [[] for _ in range(batch_size_n_best // n_best)]
    # treat alignment matrices one by one, as each has a different length
for i, (am_b, tgt_mask_b, src_len) in enumerate(
zip(align_matrix, tgt_mask, src_lens)):
valid_tgt = ~tgt_mask_b
valid_tgt_len = valid_tgt.sum()
if valid_tgt_len == 0:
            # No alignment if there is no valid tgt token
valid_alignment = None
else:
            # get valid alignment (sub-matrix of the full padded alignment matrix)
am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)) \
.view(valid_tgt_len, -1)
valid_alignment = am_valid_tgt[:, :src_len] # only keep valid src
alignments[i // n_best].append(valid_alignment)
return alignments
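
# Illustrative sketch (not part of the original module): one source sentence
# decoded with n_best=2, so the flattened batch is B = 1 * 2 = 2. Shapes and
# mask values are made up.
def _example_extract_alignment():
    align_matrix = torch.rand(2, 4, 5).softmax(dim=-1)
    # True marks EOS/PAD: the first hypothesis keeps 2 tgt tokens, the
    # second keeps 1.
    tgt_mask = torch.tensor([[False, False, True, True],
                             [False, True, True, True]])
    src_lens = torch.tensor([3, 3])
    alignments = extract_alignment(align_matrix, tgt_mask, src_lens, n_best=2)
    # alignments == [[FloatTensor of shape (2, 3), FloatTensor of shape (1, 3)]]
    return alignments
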
def build_align_pharaoh(valid_alignment):
"""Convert valid alignment matrix to i-j (from 0) Pharaoh format pairs,
or empty list if it's None.
"""
align_pairs = []
if isinstance(valid_alignment, torch.Tensor):
tgt_align_src_id = valid_alignment.argmax(dim=-1)
for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()):
align_pairs.append(str(src_id) + "-" + str(tgt_id))
align_pairs.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
align_pairs.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
return align_pairs
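
# Illustrative sketch (not part of the original module): the argmax over the
# src dimension picks one src index per tgt token, giving "src-tgt" pairs.
# The matrix values are made up.
def _example_build_align_pharaoh():
    valid_alignment = torch.tensor([[0.7, 0.2, 0.1],
                                    [0.1, 0.1, 0.8]])
    return build_align_pharaoh(valid_alignment)  # ['0-0', '2-1']
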
def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'):
"""Convert subword alignment to word alignment.
Args:
src (string): tokenized sentence in source language.
tgt (string): tokenized sentence in target language.
subword_align (string): align_pharaoh correspond to src-tgt.
m_src (string): tokenization mode used in src,
can be ["joiner", "spacer"].
m_tgt (string): tokenization mode used in tgt,
can be ["joiner", "spacer"].
Returns:
word_align (string): converted alignments correspand to
detokenized src-tgt.
"""
assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!"
assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!"
src, tgt = src.strip().split(), tgt.strip().split()
subword_align = {(int(a), int(b)) for a, b in (x.split("-")
for x in subword_align.split())}
src_map = (subword_map_by_spacer(src) if m_src == 'spacer'
else subword_map_by_joiner(src))
    tgt_map = (subword_map_by_spacer(tgt) if m_tgt == 'spacer'
               else subword_map_by_joiner(tgt))
word_align = list({"{}-{}".format(src_map[a], tgt_map[b])
for a, b in subword_align})
word_align.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
word_align.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
return " ".join(word_align)
# Helper functions
def begin_uppercase(token):
return token == SubwordMarker.BEGIN_UPPERCASE
def end_uppercase(token):
return token == SubwordMarker.END_UPPERCASE
def begin_case(token):
return token == SubwordMarker.BEGIN_CASED
def case_markup(token):
return begin_uppercase(token) \
or end_uppercase(token) \
or begin_case(token)
def subword_map_by_joiner(subwords,
original_subwords=None,
marker=SubwordMarker.JOINER):
"""Return word id for each subword token (annotate by joiner)."""
flags = [1] * len(subwords)
j = 0
finished = True
for i, tok in enumerate(subwords):
        previous_tok = subwords[i-1] if i else ""  # token at N-1
        previous_tok_2 = subwords[i-2] if i > 1 else ""  # token at N-2
# Keeps track of the original words/subwords
# ('prior_tokenization' option)
current_original_subword = "" if not original_subwords \
else original_subwords[j] if j < len(original_subwords) else ""
if tok.startswith(marker) and tok != current_original_subword:
flags[i] = 0
elif (previous_tok.endswith(marker)
or begin_case(previous_tok)
or begin_uppercase(previous_tok)) \
and not finished:
flags[i] = 0
elif previous_tok_2.endswith(marker) \
and case_markup(previous_tok) \
and not finished:
flags[i] = 0
elif end_uppercase(tok) and tok != current_original_subword:
flags[i] = 0
else:
finished = False
if tok == current_original_subword:
finished = True
j += 1
flags[0] = 0
word_group = list(accumulate(flags))
if original_subwords:
assert max(word_group) < len(original_subwords)
return word_group
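
# Illustrative sketch (not part of the original module): subwords carrying a
# joiner prefix stay in the word opened by the previous subword. The token
# list is made up.
def _example_subword_map_by_joiner():
    j = SubwordMarker.JOINER
    subwords = ["they", "are", "do", "{}ing".format(j), "it"]
    return subword_map_by_joiner(subwords)  # [0, 1, 2, 2, 3]
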
def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
"""Return word id for each subword token (annotate by spacer)."""
flags = [0] * len(subwords)
for i, tok in enumerate(subwords):
if marker in tok:
if case_markup(tok.replace(marker, "")):
if i < len(subwords)-1:
flags[i] = 1
else:
if i > 0:
previous = subwords[i-1].replace(marker, "")
if not case_markup(previous):
flags[i] = 1
# In case there is a final case_markup when new_spacer is on
for i in range(1, len(subwords)-1):
if case_markup(subwords[-i]):
flags[-i] = 0
elif subwords[-i] == marker:
flags[-i] = 0
break
word_group = list(accumulate(flags))
if word_group[0] == 1: # when dummy prefix is set
word_group = [item - 1 for item in word_group]
return word_group
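
# Illustrative sketch (not part of the original module): with spacer-style
# tokenization, every spacer-marked subword after the first starts a new
# word, and unmarked subwords continue the current word. The token list is
# made up.
def _example_subword_map_by_spacer():
    s = SubwordMarker.SPACER
    subwords = ["{}they".format(s), "{}are".format(s), "{}do".format(s), "ing"]
    return subword_map_by_spacer(subwords)  # [0, 1, 2, 2]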