import json
import os
from collections import defaultdict
from typing import List, Optional, Union

from tokenizers import AddedToken, Tokenizer, decoders, trainers
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer


def generate_sentinel_tokens(num=100, start_id=0):
    """Sentinel tokens of the form "[S_0]", "[S_1]", ... used to mark masked spans."""
    tokens = [
        AddedToken(content=f"[S_{i}]", single_word=True, normalized=False)
        for i in range(start_id, num + start_id)
    ]
    return tokens


def generate_coord_tokens(bins=1000):
    """Extra tokens used for bounding-box coordinates (xmin, ymin, xmax, ymax),
    but also for other modalities such as color maps, metadata, or poses.
    See the illustrative example after this function.
    """
    tokens = []
    coords_str = ["v0={}", "v1={}", "v2={}", "v3={}"]
    for s in coords_str:
        for i in range(bins):
            tokens.append(AddedToken(content=s.format(i), single_word=True, normalized=False))
    return tokens
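

# Illustrative example (not part of the original module): with the default
# 1000 bins, generate_coord_tokens yields 4000 AddedTokens, one per
# (coordinate, bin) pair.
#
#   coord_tokens = generate_coord_tokens(bins=1000)
#   coord_tokens[0].content    # "v0=0"
#   coord_tokens[-1].content   # "v3=999"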


def generate_object_class_tokens(dataset="coco"):
    """Tokens for the object class names of `dataset`, read from the
    object_classes.json file that sits next to this module.
    """
    with open(os.path.join(os.path.dirname(__file__), "object_classes.json")) as f:
        object_classes = json.load(f)[dataset]
    tokens = [
        AddedToken(content=class_name, single_word=True, normalized=True)
        for class_name in object_classes
    ]
    return tokens


def train_unified_wordpiece_tokenizer(
    files,
    vocab_size,
    sentinel_tokens: List[Union[str, AddedToken]] = None,
    coord_tokens: List[Union[str, AddedToken]] = None,
    object_class_tokens: List[Union[str, AddedToken]] = None,
    unk_token: Union[str, AddedToken] = "[UNK]",
    pad_token: Union[str, AddedToken] = "[PAD]",
    sos_token: Union[str, AddedToken] = "[SOS]",
    eos_token: Union[str, AddedToken] = "[EOS]",
    additional_special_tokens: List[Union[str, AddedToken]] = None,
    min_frequency=0,
    clean_text: bool = True,
    handle_chinese_chars: bool = True,
    strip_accents: Optional[bool] = None,
    lowercase: bool = True,
    wordpieces_prefix: str = "##",
    show_progress=True,
):
    """Train a WordPiece tokenizer with BERT-style normalization and
    pre-tokenization, registering the unified special tokens (pad/unk/sos/eos,
    sentinel, coordinate, and object-class tokens) with the trainer.
    See the usage sketch after this function.
    """
    tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
    tokenizer.normalizer = BertNormalizer(
        clean_text=clean_text,
        handle_chinese_chars=handle_chinese_chars,
        strip_accents=strip_accents,
        lowercase=lowercase,
    )
    tokenizer.pre_tokenizer = BertPreTokenizer()
    tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

    # Collect every special token so the trainer adds them to the vocabulary.
    special_tokens = [pad_token, unk_token, sos_token, eos_token]
    if sentinel_tokens is not None:
        special_tokens.extend(sentinel_tokens)
    if coord_tokens is not None:
        special_tokens.extend(coord_tokens)
    if object_class_tokens is not None:
        special_tokens.extend(object_class_tokens)
    if additional_special_tokens is not None:
        special_tokens.extend(additional_special_tokens)

    trainer = trainers.WordPieceTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        show_progress=show_progress,
        continuing_subword_prefix=wordpieces_prefix,
        special_tokens=special_tokens,
    )
    if isinstance(files, str):
        files = [files]
    tokenizer.train(files, trainer=trainer)
    return tokenizer
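

# Hedged usage sketch (the corpus path, vocabulary size, and output path below
# are assumptions for illustration, not taken from the original module): train
# on one or more plain-text files and register the unified special tokens.
#
#   tokenizer = train_unified_wordpiece_tokenizer(
#       files="corpus.txt",                           # hypothetical corpus path
#       vocab_size=32000,                             # assumed vocabulary size
#       sentinel_tokens=generate_sentinel_tokens(100),
#       coord_tokens=generate_coord_tokens(1000),
#       object_class_tokens=generate_object_class_tokens("coco"),
#   )
#   tokenizer.save("unified_wordpiece.json")          # hypothetical output path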


def get_sentinel_to_id_mapping(tokenizer, match_str="[S_"):
    sentinel_tokens = {k: v for k, v in tokenizer.get_vocab().items() if k.startswith(match_str)}
    # Extract the sentinel index from tokens of the form "[S_0]", "[S_1]", etc.,
    # and map it to the corresponding vocabulary id.
    sentinel_to_id = {int(k.split("_")[1][:-1]): v for k, v in sorted(sentinel_tokens.items(), key=lambda x: x[1])}
    return sentinel_to_id


def split_by_sentinel(seq_ids, sentinel_ids):
    """Group the token ids in `seq_ids` by the sentinel id that precedes them."""
    splits = defaultdict(list)
    cur_sentinel = None
    for token in seq_ids:
        if token in sentinel_ids:
            cur_sentinel = token
        else:
            splits[cur_sentinel].append(token)
    return splits


def merge_span_masking(input_seq, decoder_seq, sentinel_ids):
    """Replace each sentinel id in `input_seq` with the span of token ids the
    decoder predicted for it in `decoder_seq`.
    """
    decoder_splits = split_by_sentinel(decoder_seq, sentinel_ids)
    out_seq = []
    for token in input_seq:
        if token in sentinel_ids:
            out_seq.extend(decoder_splits[token])
        else:
            out_seq.append(token)
    return out_seq
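

# Minimal self-contained demo of the span-masking helpers above, using toy
# integer ids (the ids are made up for illustration; real ids would come from
# a trained tokenizer via get_sentinel_to_id_mapping).
if __name__ == "__main__":
    toy_sentinel_ids = {101, 102}
    # Encoder input where 101 and 102 stand in for two masked spans.
    encoder_ids = [7, 101, 9, 102, 11]
    # Decoder output: each sentinel is followed by the tokens that fill its span.
    decoder_ids = [101, 3, 4, 102, 5]
    assert split_by_sentinel(decoder_ids, toy_sentinel_ids) == {101: [3, 4], 102: [5]}
    assert merge_span_masking(encoder_ids, decoder_ids, toy_sentinel_ids) == [7, 3, 4, 9, 5, 11]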