import json
import os
from collections import defaultdict
from typing import Optional, Union, List
from tokenizers import AddedToken, decoders, trainers
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer


def generate_sentinel_tokens(num=100, start_id=0):
    """Sentinel tokens ("[S_0]", "[S_1]", ...) used to mark masked spans."""
    tokens = [
        AddedToken(content=f"[S_{i}]", single_word=True, normalized=False)
        for i in range(start_id, num + start_id)
    ]
    return tokens
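
# With the defaults this produces "[S_0]" ... "[S_99]"; start_id only shifts the
# numbering (e.g. start_id=100 would give "[S_100]" ... "[S_199]").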


def generate_coord_tokens(bins=1000):
    """Extra tokens used for bounding box coordinates (xmin, ymin, xmax, ymax),
    but also for other modalities like color maps, metadata, or poses.
    """
    tokens = []
    coords_str = ["v0={}", "v1={}", "v2={}", "v3={}"]
    for s in coords_str:
        for i in range(bins):
            tokens.append(AddedToken(content=s.format(i), single_word=True, normalized=False))
    return tokens
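
# For the default bins=1000 this yields 4 * 1000 coordinate tokens:
# "v0=0" ... "v0=999", "v1=0" ... "v3=999" (one quantization bin per slot).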


def generate_object_class_tokens(dataset="coco"):
    """Tokens for the object class names of the given detection dataset."""
    with open(os.path.join(os.path.dirname(__file__), 'object_classes.json')) as f:
        object_classes = json.load(f)[dataset]
    tokens = [
        AddedToken(content=class_name, single_word=True, normalized=True)
        for class_name in object_classes
    ]
    return tokens
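
# object_classes.json is expected to map dataset names to lists of class-name
# strings, e.g. {"coco": ["person", "bicycle", ...]} (an illustrative sketch of
# the format implied by the lookup above, not the actual file contents).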


def train_unified_wordpiece_tokenizer(
    files,
    vocab_size,
    sentinel_tokens: Optional[List[Union[str, AddedToken]]] = None,
    coord_tokens: Optional[List[Union[str, AddedToken]]] = None,
    object_class_tokens: Optional[List[Union[str, AddedToken]]] = None,
    unk_token: Union[str, AddedToken] = "[UNK]",
    pad_token: Union[str, AddedToken] = "[PAD]",
    sos_token: Union[str, AddedToken] = "[SOS]",
    eos_token: Union[str, AddedToken] = "[EOS]",
    additional_special_tokens: Optional[List[Union[str, AddedToken]]] = None,
    min_frequency=0,
    clean_text: bool = True,
    handle_chinese_chars: bool = True,
    strip_accents: Optional[bool] = None,
    lowercase: bool = True,
    wordpieces_prefix: str = "##",
    show_progress=True,
):
    """Train a WordPiece tokenizer whose vocabulary also contains the sentinel,
    coordinate, and object-class special tokens."""
    tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
    tokenizer.normalizer = BertNormalizer(
        clean_text=clean_text,
        handle_chinese_chars=handle_chinese_chars,
        strip_accents=strip_accents,
        lowercase=lowercase,
    )
    tokenizer.pre_tokenizer = BertPreTokenizer()
    tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

    # The order below determines the special tokens' ids in the trained vocabulary.
    special_tokens = [pad_token, unk_token, sos_token, eos_token]
    if sentinel_tokens is not None:
        special_tokens.extend(sentinel_tokens)
    if coord_tokens is not None:
        special_tokens.extend(coord_tokens)
    if object_class_tokens is not None:
        special_tokens.extend(object_class_tokens)
    if additional_special_tokens is not None:
        special_tokens.extend(additional_special_tokens)

    trainer = trainers.WordPieceTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        show_progress=show_progress,
        continuing_subword_prefix=wordpieces_prefix,
        special_tokens=special_tokens,
    )
    if isinstance(files, str):
        files = [files]
    tokenizer.train(files, trainer=trainer)
    return tokenizer
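
# Minimal usage sketch (the file names and vocab size below are placeholders,
# not part of this module):
#
#     tokenizer = train_unified_wordpiece_tokenizer(
#         files=["corpus.txt"],
#         vocab_size=32000,
#         sentinel_tokens=generate_sentinel_tokens(num=100),
#         coord_tokens=generate_coord_tokens(bins=1000),
#         object_class_tokens=generate_object_class_tokens("coco"),
#     )
#     tokenizer.save("unified_wordpiece.json")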


def get_sentinel_to_id_mapping(tokenizer, match_str="[S_"):
    """Map each sentinel index to its vocabulary id."""
    sentinel_tokens = {k: v for k, v in tokenizer.get_vocab().items() if k.startswith(match_str)}
    # Extract the sentinel index from tokens of the form "[S_0]", "[S_1]", etc.
    sentinel_to_id = {int(k.split("_")[1][:-1]): v for k, v in sorted(sentinel_tokens.items(), key=lambda x: x[1])}
    return sentinel_to_id


def split_by_sentinel(seq_ids, sentinel_ids):
    """Group token ids by the sentinel id that precedes them."""
    splits = defaultdict(list)
    cur_sentinel = None
    for token in seq_ids:
        if token in sentinel_ids:
            cur_sentinel = token
        else:
            # Tokens seen before the first sentinel are collected under None.
            splits[cur_sentinel].append(token)
    return splits
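
# Toy example (100 and 101 stand in for sentinel ids):
#     split_by_sentinel([100, 7, 8, 101, 9], {100, 101}) -> {100: [7, 8], 101: [9]}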


def merge_span_masking(input_seq, decoder_seq, sentinel_ids):
    """Reconstruct a sequence by replacing each sentinel in `input_seq`
    with the span the decoder produced for it."""
    decoder_splits = split_by_sentinel(decoder_seq, sentinel_ids)
    out_seq = []
    for token in input_seq:
        if token in sentinel_ids:
            out_seq.extend(decoder_splits[token])
        else:
            out_seq.append(token)
    return out_seq
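
# Toy example of reassembling a span-masked sequence (100 and 101 stand in for
# sentinel ids, the other integers for ordinary token ids):
#     input_seq   = [5, 100, 6, 101]      # encoder input with masked spans
#     decoder_seq = [100, 7, 8, 101, 9]   # decoder output: one span per sentinel
#     merge_span_masking(input_seq, decoder_seq, {100, 101}) -> [5, 7, 8, 6, 9]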