import json
import os
from collections import defaultdict
from typing import Optional, Union, List

from tokenizers import AddedToken, Tokenizer, decoders, trainers
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer


def generate_sentinel_tokens(num=100, start_id=0):
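    """Sentinel tokens of the form "[S_0]", "[S_1]", ... used to mark masked spans."""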
    tokens = [
        AddedToken(content=f"[S_{i}]", single_word=True, normalized=False)
        for i in range(start_id, num + start_id)
    ]

    return tokens

def generate_coord_tokens(bins=1000):
    """Extra tokens that are used for bounding box coordinates, 
    xmin, ymin, xmax, ymax, but also other modalities like color
    maps, metadata, or poses.
    """
    tokens = []
    coords_str = ["v0={}", "v1={}", "v2={}", "v3={}"]

    for s in coords_str:
        for i in range(bins):
            tokens.append(AddedToken(content=s.format(i), single_word=True, normalized=False))

    return tokens

def generate_object_class_tokens(dataset="coco"):
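    """Tokens for the object class names of `dataset`, read from object_classes.json."""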
    with open(os.path.join(os.path.dirname(__file__), 'object_classes.json')) as f:
        object_classes = json.load(f)[dataset]

    tokens = [
        AddedToken(content=class_name, single_word=True, normalized=True)
        for class_name in object_classes
    ]

    return tokens


def train_unified_wordpiece_tokenizer(
        files,
        vocab_size,
        sentinel_tokens: Optional[List[Union[str, AddedToken]]] = None,
        coord_tokens: Optional[List[Union[str, AddedToken]]] = None,
        object_class_tokens: Optional[List[Union[str, AddedToken]]] = None,
        unk_token: Union[str, AddedToken] = "[UNK]",
        pad_token: Union[str, AddedToken] = "[PAD]",
        sos_token: Union[str, AddedToken] = "[SOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        additional_special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        min_frequency=0,
        clean_text: bool = True,
        handle_chinese_chars: bool = True,
        strip_accents: Optional[bool] = None,
        lowercase: bool = True,
        wordpieces_prefix: str = "##",
        show_progress=True,
):
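    """Train a WordPiece tokenizer on `files` with BERT-style normalization and
    pre-tokenization, registering sentinel, coordinate, object-class, and any
    additional special tokens alongside the pad/unk/sos/eos tokens.
    """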
    tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))

    tokenizer.normalizer = BertNormalizer(
            clean_text=clean_text,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
    tokenizer.pre_tokenizer = BertPreTokenizer()
    tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

    special_tokens = [pad_token, unk_token, sos_token, eos_token]

    if sentinel_tokens is not None:
        special_tokens.extend(sentinel_tokens)
    if coord_tokens is not None:
        special_tokens.extend(coord_tokens)
    if object_class_tokens is not None:
        special_tokens.extend(object_class_tokens)
    if additional_special_tokens is not None:
        special_tokens.extend(additional_special_tokens)

    trainer = trainers.WordPieceTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        show_progress=show_progress,
        continuing_subword_prefix=wordpieces_prefix,
        special_tokens=special_tokens,
    )

    if isinstance(files, str):
        files = [files]

    tokenizer.train(files, trainer=trainer)

    return tokenizer


def get_sentinel_to_id_mapping(tokenizer, match_str="[S_"):
    sentinel_tokens = {k: v for k, v in tokenizer.get_vocab().items() if k.startswith(match_str)}
    # Map each sentinel's index to its vocab id; the tokens are of the form "[S_0]", "[S_1]", etc.
    sentinel_to_id = {int(k.split("_")[1][:-1]): v for k, v in sorted(sentinel_tokens.items(), key=lambda x: x[1])}
    return sentinel_to_id


def split_by_sentinel(seq_ids, sentinel_ids):
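    """Group the ids in `seq_ids` by the most recently seen sentinel id.

    Ids appearing before any sentinel are collected under the key None.
    """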
    splits = defaultdict(list)
    cur_sentinel = None
    for token in seq_ids:
        if token in sentinel_ids:
            cur_sentinel = token
        else:
            splits[cur_sentinel].append(token)

    return splits


def merge_span_masking(input_seq, decoder_seq, sentinel_ids):
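    """Reconstruct the full sequence from a span-masked pair: each sentinel id in
    `input_seq` is replaced by the tokens the decoder produced for that sentinel.
    """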
    decoder_splits = split_by_sentinel(decoder_seq, sentinel_ids)
    out_seq = []
    for token in input_seq:
        if token in sentinel_ids:
            out_seq.extend(decoder_splits[token])
        else:
            out_seq.append(token)
    return out_seq
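

if __name__ == "__main__":
    # Illustrative sketch only: the corpus path, vocab size, and example
    # sentences below are placeholder assumptions, not values from this repo.
    tokenizer = train_unified_wordpiece_tokenizer(
        files="corpus.txt",
        vocab_size=32000,
        sentinel_tokens=generate_sentinel_tokens(num=100),
        coord_tokens=generate_coord_tokens(bins=1000),
        object_class_tokens=generate_object_class_tokens(dataset="coco"),
    )

    sentinel_ids = set(get_sentinel_to_id_mapping(tokenizer).values())

    # Encoder input with two masked spans, and the decoder sequence that fills them in.
    input_ids = tokenizer.encode("a photo of [S_0] sitting on [S_1]").ids
    decoder_ids = tokenizer.encode("[S_0] a cat [S_1] the mat").ids
    print(tokenizer.decode(merge_span_masking(input_ids, decoder_ids, sentinel_ids)))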