"""Tweaked version of corresponding AllenNLP file"""
import logging
from collections import defaultdict
from typing import Dict, List, Callable

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.tokenizers.token import Token
from allennlp.data.vocabulary import Vocabulary
from overrides import overrides
from transformers import AutoTokenizer

from utils.helpers import START_TOKEN

from gector.tokenization import tokenize_batch
import copy

logger = logging.getLogger(__name__)


# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.


class TokenizerIndexer(TokenIndexer[int]):
    """
    A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
    If you are using one of the pretrained BERT models, you'll want to use the ``PretrainedBertIndexer``
    subclass rather than this base class.

    Parameters
    ----------
    tokenizer : ``Callable[[str], List[str]]``
        A function that does the actual tokenization.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Any inputs longer than this will
        either be truncated (default), or be split apart and batched using a
        sliding window.
    token_min_padding_length : ``int``, optional (default=``0``)
        See :class:`TokenIndexer`.
    """

    def __init__(self,
                 tokenizer: Callable[[str], List[str]],
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 3,
                 token_min_padding_length: int = 0) -> None:
        super().__init__(token_min_padding_length)

        # The BERT code itself does a two-step tokenization:
        #    sentence -> [words], and then word -> [wordpieces]
        # In AllenNLP, the first step is implemented as the ``BertBasicWordSplitter``,
        # and this token indexer handles the second.

        self.tokenizer = tokenizer
        self.max_pieces_per_token = max_pieces_per_token
        self.max_pieces = max_pieces
        # Hard-coded cap on the total number of wordpieces per sentence.
        self.max_pieces_per_sentence = 80

    @overrides
    def tokens_to_indices(self, tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[int]]:
        # ``tokenize_batch`` expects a batch of sentences, so wrap the single
        # sentence in a one-element batch, run the fast batched tokenizer,
        # then strip the batch dimension from every output field.
        text = [token.text for token in tokens]
        batch_tokens = [text]

        output_fast = tokenize_batch(self.tokenizer,
                                     batch_tokens,
                                     max_bpe_length=self.max_pieces,
                                     max_bpe_pieces=self.max_pieces_per_token)
        output_fast = {k: v[0] for k, v in output_fast.items()}
        return output_fast
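
    # Shape sketch (hypothetical values; the actual keys are whatever
    # ``tokenize_batch`` produces, matching ``get_keys`` below), e.g. for
    # ``index_name == "bert"``:
    #     {"bert": [...wordpiece ids...], "bert-offsets": [...],
    #      "bert-type-ids": [...], "mask": [...]}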

    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        # If we only use pretrained models, we don't need to do anything here.
        pass

    @overrides
    def get_padding_token(self) -> int:
        return 0

    @overrides
    def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=unused-argument
        return {}

    @overrides
    def pad_token_sequence(self,
                           tokens: Dict[str, List[int]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        return {key: pad_sequence_to_length(val, desired_num_tokens[key])
                for key, val in tokens.items()}
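    # For instance, pad_sequence_to_length([3, 7], 4) -> [3, 7, 0, 0]: each
    # list is right-padded with zeros up to its desired length (an
    # illustration of the AllenNLP helper's default behaviour).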

    @overrides
    def get_keys(self, index_name: str) -> List[str]:
        """
        We need to override this because the indexer generates multiple keys.
        For ``index_name == "bert"`` this returns
        ``["bert", "bert-offsets", "bert-type-ids", "mask"]``.
        """
        # pylint: disable=no-self-use
        return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]


class PretrainedBertIndexer(TokenizerIndexer):
    # pylint: disable=line-too-long
    """
    A ``TokenIndexer`` corresponding to a pretrained BERT model.

    Parameters
    ----------
    pretrained_model: ``str``
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .txt file with its vocabulary.
        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    do_lowercase: ``bool``, optional (default = True)
        Whether to lowercase the tokens before converting to wordpiece ids.
    max_pieces: int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Any inputs longer than this will
        either be truncated (default), or be split apart and batched using a
        sliding window.
    """

    def __init__(self,
                 pretrained_model: str,
                 do_lowercase: bool = True,
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 5,
                 special_tokens_fix: int = 0) -> None:

        if pretrained_model.endswith("-cased") and do_lowercase:
            logger.warning("Your BERT model appears to be cased, "
                           "but your indexer is lowercasing tokens.")
        elif pretrained_model.endswith("-uncased") and not do_lowercase:
            logger.warning("Your BERT model appears to be uncased, "
                           "but your indexer is not lowercasing tokens.")

        model_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model, do_lower_case=do_lowercase,
            do_basic_tokenize=False, use_fast=True)

        # Normalize the tokenizer so it always exposes a ``vocab`` dict:
        # BPE-style tokenizers (e.g. GPT-2, RoBERTa) store it as ``encoder``,
        # while sentencepiece-based ones expose an ``sp_model``, from which we
        # rebuild a piece -> id mapping (unknown pieces default to id 1).
        if hasattr(model_tokenizer, 'encoder'):
            model_tokenizer.vocab = model_tokenizer.encoder
        if hasattr(model_tokenizer, 'sp_model'):
            model_tokenizer.vocab = defaultdict(lambda: 1)
            for i in range(model_tokenizer.sp_model.get_piece_size()):
                model_tokenizer.vocab[model_tokenizer.sp_model.id_to_piece(i)] = i

        if special_tokens_fix:
            # Register START_TOKEN under the last id so the plain ``vocab``
            # dict stays in sync with the extended tokenizer.
            model_tokenizer.add_tokens([START_TOKEN])
            model_tokenizer.vocab[START_TOKEN] = len(model_tokenizer) - 1

        super().__init__(tokenizer=model_tokenizer,
                         max_pieces=max_pieces,
                         max_pieces_per_token=max_pieces_per_token)
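

# A minimal usage sketch (illustrative, not part of the original module; it
# assumes the "bert-base-uncased" tokenizer files can be downloaded):
if __name__ == "__main__":
    indexer = PretrainedBertIndexer(pretrained_model="bert-base-uncased",
                                    special_tokens_fix=1)
    example = [Token("the"), Token("quick"), Token("fox")]
    print(indexer.tokens_to_indices(example, Vocabulary(), "bert"))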