File size: 8,874 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# coding: utf-8
# Copyright 2019 Sinovation Ventures AI Institute
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils for ngram for ZEN2 model."""

import os
import logging
import math
import numpy as np
import torch
from transformers import cached_path

# File name of the n-gram dictionary; joined onto a directory path whenever a
# directory is given instead of a file.
NGRAM_DICT_NAME = 'ngram.txt'

logger = logging.getLogger(__name__)
# Maps known pretrained model names to the URL of their n-gram dictionary file.
PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'IDEA-CCNL/Erlangshen-ZEN2-345M-Chinese': 'https://huggingface.co/IDEA-CCNL/Erlangshen-ZEN2-345M-Chinese/resolve/main/ngram.txt',
    'IDEA-CCNL/Erlangshen-ZEN2-668M-Chinese': 'https://huggingface.co/IDEA-CCNL/Erlangshen-ZEN2-668M-Chinese/resolve/main/ngram.txt',
}


class ZenNgramDict(object):
    """
    Dict class to store the ngram

    The dictionary is read from a text file holding one ``ngram,frequency``
    record per line. Three lookup structures are filled:

    * ``id_to_ngram_list``: n-gram keys in load order, ``"[pad]"`` first.
    * ``ngram_to_id_dict``: n-gram key -> id, with ``"[pad]"`` mapped to 0.
    * ``ngram_to_freq_dict``: n-gram key -> integer frequency.
    """

    def __init__(self, ngram_freq_path, tokenizer=None, max_ngram_in_seq=128):
        """Constructs ZenNgramDict

        :param ngram_freq_path: path of the ``ngram,frequency`` file, or a
            directory containing a file named ``NGRAM_DICT_NAME``
        :param tokenizer: optional object exposing ``tokenize(str)``; when
            given, every n-gram is tokenized and stored as a tuple of tokens,
            otherwise the n-gram text is split on single spaces
        :param max_ngram_in_seq: stored as ``self.max_ngram_in_seq``
        """
        if os.path.isdir(ngram_freq_path):
            ngram_freq_path = os.path.join(ngram_freq_path, NGRAM_DICT_NAME)
        self.ngram_freq_path = ngram_freq_path
        self.max_ngram_in_seq = max_ngram_in_seq
        self.max_ngram_len = 8
        self.id_to_ngram_list = ["[pad]"]
        self.ngram_to_id_dict = {"[pad]": 0}
        self.ngram_to_freq_dict = {}

        logger.info("loading ngram frequency file {}".format(ngram_freq_path))
        with open(ngram_freq_path, "r", encoding="utf-8") as fin:
            for i, line in enumerate(fin):
                items = line.strip().split(",")
                # Only well-formed "ngram,freq" records are loaded.
                if len(items) != 2:
                    continue
                ngram, freq = items
                if tokenizer:
                    tokens = tuple(tokenizer.tokenize(ngram))
                    # Keep the raw text as the key when the tokenizer could
                    # not represent part of the n-gram.
                    if any("[UNK]" in token for token in tokens):
                        tokens = ngram
                else:
                    tokens = tuple(ngram.split(" "))
                self.id_to_ngram_list.append(tokens)
                # The id is derived from the line number (skipped lines still
                # consume a number); this numbering is kept unchanged.
                # NOTE(review): presumably ids must line up with pretrained
                # n-gram embeddings -- confirm before renumbering.
                self.ngram_to_id_dict[tokens] = i + 1
                self.ngram_to_freq_dict[tokens] = int(freq)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, **kwargs):
        """
        Instantiate a ZenNgramDict from a pre-trained n-gram dictionary file.
        Download and cache the n-gram file if needed.

        :param pretrained_model_name_or_path: a key of
            ``PRETRAINED_VOCAB_ARCHIVE_MAP``, or a path/URL of an n-gram file,
            or a directory containing ``NGRAM_DICT_NAME``
        :param cache_dir: forwarded to ``cached_path``
        :param kwargs: forwarded to the constructor, except ``do_lower_case``
            which is only used for the casing warnings below
        :return: the loaded dictionary, or ``None`` when the file could not be
            resolved (an error is logged in that case)
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            ngram_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
                               "you may want to check this behavior.")
                kwargs['do_lower_case'] = False
            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                               "but you may want to check this behavior.")
                kwargs['do_lower_case'] = True
        else:
            ngram_file = pretrained_model_name_or_path
        if os.path.isdir(ngram_file):
            ngram_file = os.path.join(ngram_file, NGRAM_DICT_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_ngram_file = cached_path(ngram_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download vocabulary.".format(
                        ngram_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                        ngram_file))
            return None
        if resolved_ngram_file == ngram_file:
            logger.info("loading vocabulary file {}".format(ngram_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                ngram_file, resolved_ngram_file))
        # The constructor has no `do_lower_case` parameter; forwarding it
        # would raise TypeError, so it is dropped before instantiation.
        kwargs.pop('do_lower_case', None)
        # Instantiate ngram.
        ngram_dict = cls(resolved_ngram_file, **kwargs)
        return ngram_dict

    def save(self, ngram_freq_path):
        """Write the dictionary as ``ngram,frequency`` lines.

        :param ngram_freq_path: directory in which the file named
            ``NGRAM_DICT_NAME`` is written
        """
        ngram_freq_path = os.path.join(ngram_freq_path, NGRAM_DICT_NAME)
        with open(ngram_freq_path, "w+", encoding="utf-8") as fout:
            for ngram, freq in self.ngram_to_freq_dict.items():
                # Keys are token tuples, or the raw n-gram string for entries
                # whose tokenization contained "[UNK]". A raw string must be
                # written verbatim; joining it would put a space between
                # every character.
                text = ngram if isinstance(ngram, str) else " ".join(ngram)
                fout.write("{},{}\n".format(text, freq))


def extract_ngram_feature(tokens, ngram_dict, max_seq_len, seg_id_limit):
    """Find dictionary n-grams inside ``tokens`` and build padded features.

    Every contiguous span of ``tokens`` with a length from 2 up to
    ``ngram_dict.max_ngram_len - 1`` is looked up (as a tuple) in
    ``ngram_dict.ngram_to_id_dict``. Matches are ordered by n-gram id and
    truncated to a budget proportional to ``len(tokens) / max_seq_len``,
    never exceeding ``ngram_dict.max_ngram_in_seq``.

    :param tokens: sequence of tokens to scan
    :param ngram_dict: object providing ``max_ngram_len``,
        ``max_ngram_in_seq``, ``ngram_to_id_dict`` and ``ngram_to_freq_dict``
    :param max_seq_len: sequence length used to scale the match budget
    :param seg_id_limit: matches starting before this position get segment
        id 0, all others segment id 1
    :return: dict with ``ngram_ids``, ``ngram_positions``, ``ngram_lengths``,
        ``ngram_seg_ids`` and ``ngram_freqs`` (lists zero-padded to
        ``max_ngram_in_seq``), ``ngram_tuples`` (unpadded list of matched
        spans) and ``ngram_masks`` (boolean array, True for real entries)
    """
    # ----------- code for ngram BEGIN-----------
    ngram_matches = []
    # p is the length of the candidate span, q its starting position.
    max_gram_n = ngram_dict.max_ngram_len
    for p in range(2, max_gram_n):
        for q in range(0, len(tokens) - p + 1):
            character_segment = tuple(tokens[q:q + p])
            if character_segment in ngram_dict.ngram_to_id_dict:
                ngram_index = ngram_dict.ngram_to_id_dict[character_segment]
                ngram_freq = ngram_dict.ngram_to_freq_dict[character_segment]
                ngram_matches.append([ngram_index, q, p, character_segment, ngram_freq])

    ngram_matches = sorted(ngram_matches, key=lambda s: s[0])
    # Scale the number of kept matches with how much of the sequence is used.
    max_word_in_seq_proportion = math.ceil((len(tokens) / max_seq_len) * ngram_dict.max_ngram_in_seq)
    # Never keep more matches than the fixed-size mask/padding can hold; this
    # matters when len(tokens) exceeds max_seq_len.
    max_word_in_seq_proportion = min(max_word_in_seq_proportion, ngram_dict.max_ngram_in_seq)
    if len(ngram_matches) > max_word_in_seq_proportion:
        ngram_matches = ngram_matches[:max_word_in_seq_proportion]
    ngram_ids = [ngram[0] for ngram in ngram_matches]
    ngram_positions = [ngram[1] for ngram in ngram_matches]
    ngram_lengths = [ngram[2] for ngram in ngram_matches]
    ngram_tuples = [ngram[3] for ngram in ngram_matches]
    ngram_freqs = [ngram[4] for ngram in ngram_matches]
    ngram_seg_ids = [0 if position < seg_id_limit else 1 for position in
                     ngram_positions]

    # Builtin bool: the `np.bool` alias no longer exists in current NumPy.
    ngram_mask_array = np.zeros(ngram_dict.max_ngram_in_seq, dtype=bool)
    ngram_mask_array[:len(ngram_ids)] = 1

    # Zero-pad up to the max word in seq length.
    padding = [0] * (ngram_dict.max_ngram_in_seq - len(ngram_ids))
    ngram_ids += padding
    ngram_positions += padding
    ngram_lengths += padding
    ngram_seg_ids += padding
    ngram_freqs += padding

    # ----------- code for ngram END-----------

    return {
        "ngram_ids": ngram_ids,
        "ngram_positions": ngram_positions,
        "ngram_lengths": ngram_lengths,
        "ngram_tuples": ngram_tuples,
        "ngram_seg_ids": ngram_seg_ids,
        "ngram_masks": ngram_mask_array,
        "ngram_freqs": ngram_freqs,
    }


def construct_ngram_matrix(ngram_data, max_seq_length):
    """Build the row-normalized position/frequency matrix for matched n-grams.

    Column ``i`` of the matrix holds the frequency of the ``i``-th n-gram on
    the rows (token positions) that the n-gram covers, i.e. from
    ``ngram_positions[i]`` to ``ngram_positions[i] + ngram_lengths[i]``
    (exclusive). Each row is then divided by its own sum plus ``1e-10``, so
    rows without any n-gram stay all-zero.

    :param ngram_data: mapping with ``ngram_ids``, ``ngram_masks``,
        ``ngram_positions``, ``ngram_lengths`` and ``ngram_freqs``; only the
        leading entries whose mask equals 1 are used
    :param max_seq_length: number of rows of the resulting matrix
    :return: ``numpy`` array of dtype float64 with shape
        ``(max_seq_length, len(ngram_data["ngram_ids"]))``
    """
    max_ngram_in_sequence = len(ngram_data["ngram_ids"])
    ngram_ids_num = sum(1 for x in ngram_data["ngram_masks"] if x == 1)

    # np.float64 is the dtype the removed `np.float` alias used to resolve to.
    ngram_positions_matrix = np.zeros(shape=(max_seq_length, max_ngram_in_sequence), dtype=np.float64)
    for i in range(ngram_ids_num):
        start = ngram_data["ngram_positions"][i]
        end = start + ngram_data["ngram_lengths"][i]
        ngram_positions_matrix[start:end, i] = ngram_data["ngram_freqs"][i]

    ngram_positions_matrix_t = torch.from_numpy(ngram_positions_matrix)
    # Per-row totals kept as a column so the division broadcasts across the
    # n-gram axis; the epsilon avoids dividing by zero on empty rows.
    row_totals = torch.sum(ngram_positions_matrix_t, 1, keepdim=True) + 1e-10
    ngram_positions_matrix_t = torch.div(ngram_positions_matrix_t, row_totals)

    return ngram_positions_matrix_t.numpy()