# -*- coding: utf-8 -*-

import torch
from itertools import accumulate
from onmt.constants import SubwordMarker


def make_batch_align_matrix(index_tensor, size=None, normalize=False):
    """
    Convert a sparse index_tensor into a batch of alignment matrices,
    with each row normalized to sum to 1 if ``normalize`` is set.

    Args:
        index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id].
        size (List[int]): size of the dense output tensor,
            as [batch_size, tgt_len, src_len].
        normalize (bool): if True, normalize along the last (src) dim so
            that each tgt row sums to 1.

    Returns:
        FloatTensor: dense alignment matrices of shape ``size``.
    """
    n_fill, device = index_tensor.size(0), index_tensor.device
    value_tensor = torch.ones([n_fill], dtype=torch.float)
    dense_tensor = torch.sparse_coo_tensor(
        index_tensor.t(), value_tensor, size=size, device=device).to_dense()
    if normalize:
        row_sum = dense_tensor.sum(-1, keepdim=True)  # sum by row(tgt)
        # threshold on 1 to avoid div by 0
        torch.nn.functional.threshold(row_sum, 1, 1, inplace=True)
        dense_tensor.div_(row_sum)
    return dense_tensor
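

# Illustrative sketch: a hypothetical helper showing, on assumed toy values,
# how three sparse [batch_id, tgt_id, src_id] entries become a dense
# (2, 2, 3) alignment matrix whose non-empty target rows each sum to 1.
def _example_make_batch_align_matrix():
    index = torch.tensor([[0, 0, 1],
                          [0, 1, 2],
                          [1, 0, 0]], dtype=torch.long)
    matrix = make_batch_align_matrix(index, size=[2, 2, 3], normalize=True)
    # matrix[0] == [[0., 1., 0.], [0., 0., 1.]]
    # matrix[1] == [[1., 0., 0.], [0., 0., 0.]]  (empty row left as zeros)
    return matrix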


def extract_alignment(align_matrix, tgt_mask, src_lens, n_best):
    """
    Extract a batched align_matrix into per-sentence alignment matrices,
    using tgt_mask to filter out invalid tgt positions such as EOS/PAD.
    BOS is already excluded from tgt_mask in order to match the prediction.

    Args:
        align_matrix (Tensor): ``(B, tgt_len, src_len)``,
            attention head normalized by Softmax(dim=-1).
        tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS/PAD positions.
        src_lens (LongTensor): ``(B,)``, valid src lengths.
        n_best (int): number of parallel translations per source sentence.
        * B denotes the flattened batch: B = batch_size * n_best.

    Returns:
        alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best)``,
            the valid alignment matrix (or None for a blank prediction)
            for each translation.
    """
    batch_size_n_best = align_matrix.size(0)
    assert batch_size_n_best % n_best == 0

    alignments = [[] for _ in range(batch_size_n_best // n_best)]

    # treat alignment matrices one by one as each has a different length
    for i, (am_b, tgt_mask_b, src_len) in enumerate(
            zip(align_matrix, tgt_mask, src_lens)):
        valid_tgt = ~tgt_mask_b
        valid_tgt_len = valid_tgt.sum()
        if valid_tgt_len == 0:
            # No alignment if there is no valid tgt token
            valid_alignment = None
        else:
            # get valid alignment (sub-matrix of the padded alignment matrix)
            am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)) \
                               .view(valid_tgt_len, -1)
            valid_alignment = am_valid_tgt[:, :src_len]  # only keep valid src
        alignments[i // n_best].append(valid_alignment)

    return alignments
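

# Illustrative sketch: a hypothetical helper on assumed toy values, showing
# how a flattened batch of B = batch_size * n_best = 2 alignment matrices is
# split per sentence, dropping masked tgt positions and padded src columns.
def _example_extract_alignment():
    align_matrix = torch.tensor([[[0.9, 0.1, 0.0],
                                  [0.2, 0.7, 0.1]],
                                 [[0.5, 0.5, 0.0],
                                  [0.3, 0.3, 0.4]]])
    tgt_mask = torch.tensor([[False, False],
                             [False, True]])  # 2nd sentence: last tgt is PAD
    src_lens = torch.tensor([3, 2])
    alignments = extract_alignment(align_matrix, tgt_mask, src_lens, n_best=1)
    # alignments[0][0] has shape (2, 3); alignments[1][0] has shape (1, 2)
    return alignments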


def build_align_pharaoh(valid_alignment):
    """Convert valid alignment matrix to i-j (from 0) Pharaoh format pairs,
    or empty list if it's None.
    """
    align_pairs = []
    if isinstance(valid_alignment, torch.Tensor):
        tgt_align_src_id = valid_alignment.argmax(dim=-1)

        for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()):
            align_pairs.append(str(src_id) + "-" + str(tgt_id))
        align_pairs.sort(key=lambda x: int(x.split('-')[-1]))  # sort by tgt_id
        align_pairs.sort(key=lambda x: int(x.split('-')[0]))  # sort by src_id
    return align_pairs
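

# Illustrative sketch: a hypothetical helper with an assumed 3x3 alignment
# matrix; each tgt row's argmax over src positions yields one "src-tgt" pair.
def _example_build_align_pharaoh():
    valid_alignment = torch.tensor([[0.9, 0.1, 0.0],
                                    [0.2, 0.7, 0.1],
                                    [0.1, 0.1, 0.8]])
    return build_align_pharaoh(valid_alignment)  # -> ['0-0', '1-1', '2-2']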


def to_word_align(src, tgt, subword_align, m_src='joiner', m_tgt='joiner'):
    """Convert subword alignment to word alignment.

    Args:
        src (string): tokenized sentence in source language.
        tgt (string): tokenized sentence in target language.
        subword_align (string): subword alignment in Pharaoh format
            corresponding to src-tgt.
        m_src (string): tokenization mode used in src,
            can be ["joiner", "spacer"].
        m_tgt (string): tokenization mode used in tgt,
            can be ["joiner", "spacer"].

    Returns:
        word_align (string): converted alignment corresponding to the
            detokenized src-tgt.
    """
    assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!"
    assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!"

    src, tgt = src.strip().split(), tgt.strip().split()
    subword_align = {(int(a), int(b)) for a, b in (x.split("-")
                     for x in subword_align.split())}

    src_map = (subword_map_by_spacer(src) if m_src == 'spacer'
               else subword_map_by_joiner(src))

    tgt_map = (subword_map_by_spacer(tgt) if m_tgt == 'spacer'
               else subword_map_by_joiner(tgt))

    word_align = list({"{}-{}".format(src_map[a], tgt_map[b])
                       for a, b in subword_align})
    word_align.sort(key=lambda x: int(x.split('-')[-1]))  # sort by tgt_id
    word_align.sort(key=lambda x: int(x.split('-')[0]))  # sort by src_id
    return " ".join(word_align)


# Helper functions
def begin_uppercase(token):
    return token == SubwordMarker.BEGIN_UPPERCASE


def end_uppercase(token):
    return token == SubwordMarker.END_UPPERCASE


def begin_case(token):
    return token == SubwordMarker.BEGIN_CASED


def case_markup(token):
    return begin_uppercase(token) \
        or end_uppercase(token) \
        or begin_case(token)


def subword_map_by_joiner(subwords,
                          original_subwords=None,
                          marker=SubwordMarker.JOINER):
    """Return word id for each subword token (annotate by joiner)."""

    flags = [1] * len(subwords)
    j = 0
    finished = True
    for i, tok in enumerate(subwords):

        previous_tok = subwords[i-1] if i else ""  # previous token
        previous_tok_2 = subwords[i-2] if i > 1 else ""  # token two back
        # Keeps track of the original words/subwords
        # ('prior_tokenization' option)
        current_original_subword = "" if not original_subwords \
            else original_subwords[j] if j < len(original_subwords) else ""

        if tok.startswith(marker) and tok != current_original_subword:
            flags[i] = 0
        elif (previous_tok.endswith(marker)
                or begin_case(previous_tok)
                or begin_uppercase(previous_tok)) \
                and not finished:
            flags[i] = 0
        elif previous_tok_2.endswith(marker) \
                and case_markup(previous_tok) \
                and not finished:
            flags[i] = 0
        elif end_uppercase(tok) and tok != current_original_subword:
            flags[i] = 0
        else:
            finished = False
            if tok == current_original_subword:
                finished = True
            j += 1

    flags[0] = 0
    word_group = list(accumulate(flags))

    if original_subwords:
        assert max(word_group) < len(original_subwords)
    return word_group
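

# Illustrative sketch: a hypothetical helper showing that subwords carrying
# the joiner (SubwordMarker.JOINER) share the word id of the token they
# attach to.
def _example_subword_map_by_joiner():
    j = SubwordMarker.JOINER
    # expected word ids: [0, 0, 1, 1]
    return subword_map_by_joiner(["hel", j + "lo", "world", j + "!"])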


def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
    """Return word id for each subword token (annotate by spacer)."""
    flags = [0] * len(subwords)
    for i, tok in enumerate(subwords):
        if marker in tok:
            if case_markup(tok.replace(marker, "")):
                if i < len(subwords)-1:
                    flags[i] = 1
            else:
                if i > 0:
                    previous = subwords[i-1].replace(marker, "")
                    if not case_markup(previous):
                        flags[i] = 1

    # In case there is a final case_markup when new_spacer is on
    for i in range(1, len(subwords)-1):
        if case_markup(subwords[-i]):
            flags[-i] = 0
        elif subwords[-i] == marker:
            flags[-i] = 0
            break

    word_group = list(accumulate(flags))
    if word_group[0] == 1:  # when dummy prefix is set
        word_group = [item - 1 for item in word_group]
    return word_group
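

# Illustrative sketch: a hypothetical helper showing that a spacer-marked
# subword (SubwordMarker.SPACER) starts a new word, except at sentence start.
def _example_subword_map_by_spacer():
    s = SubwordMarker.SPACER
    # expected word ids: [0, 0, 1, 1]
    return subword_map_by_spacer([s + "hel", "lo", s + "world", "!"])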