import json
import os
from collections import defaultdict
from typing import Optional, Union, List
from tokenizers import AddedToken, decoders, trainers
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer


def generate_sentinel_tokens(num=100, start_id=0):
    """Sentinel tokens of the form [S_0], [S_1], ..., used as span-masking
    placeholders (see merge_span_masking below).
    """
    tokens = [
        AddedToken(content=f"[S_{i}]", single_word=True, normalized=False)
        for i in range(start_id, num + start_id)
    ]
    return tokens
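
# E.g. generate_sentinel_tokens(num=3) yields the tokens "[S_0]", "[S_1]", "[S_2]".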


def generate_coord_tokens(bins=1000):
    """Extra tokens used for bounding box coordinates (xmin, ymin, xmax, ymax),
    and also for other modalities like color maps, metadata, or poses.
    """
    tokens = []
    coords_str = ["v0={}", "v1={}", "v2={}", "v3={}"]
    for s in coords_str:
        for i in range(bins):
            tokens.append(AddedToken(content=s.format(i), single_word=True, normalized=False))
    return tokens
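
# E.g. generate_coord_tokens(bins=2) yields "v0=0", "v0=1", "v1=0", "v1=1",
# "v2=0", "v2=1", "v3=0", "v3=1".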


def generate_object_class_tokens(dataset="coco"):
    """One AddedToken per object class name listed for `dataset` in object_classes.json."""
    with open(os.path.join(os.path.dirname(__file__), 'object_classes.json')) as f:
        object_classes = json.load(f)[dataset]
    tokens = [
        AddedToken(content=class_name, single_word=True, normalized=True)
        for class_name in object_classes
    ]
    return tokens
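
# E.g. generate_object_class_tokens("coco") yields one AddedToken per class name
# stored under the "coco" key of object_classes.json (class names such as
# "person" or "bicycle", assuming standard COCO labels in that file).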


def train_unified_wordpiece_tokenizer(
    files,
    vocab_size,
    sentinel_tokens: Optional[List[Union[str, AddedToken]]] = None,
    coord_tokens: Optional[List[Union[str, AddedToken]]] = None,
    object_class_tokens: Optional[List[Union[str, AddedToken]]] = None,
    unk_token: Union[str, AddedToken] = "[UNK]",
    pad_token: Union[str, AddedToken] = "[PAD]",
    sos_token: Union[str, AddedToken] = "[SOS]",
    eos_token: Union[str, AddedToken] = "[EOS]",
    additional_special_tokens: Optional[List[Union[str, AddedToken]]] = None,
    min_frequency=0,
    clean_text: bool = True,
    handle_chinese_chars: bool = True,
    strip_accents: Optional[bool] = None,
    lowercase: bool = True,
    wordpieces_prefix: str = "##",
    show_progress=True,
):
    tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
    tokenizer.normalizer = BertNormalizer(
        clean_text=clean_text,
        handle_chinese_chars=handle_chinese_chars,
        strip_accents=strip_accents,
        lowercase=lowercase,
    )
    tokenizer.pre_tokenizer = BertPreTokenizer()
    tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

    # Collect all special tokens; they are passed to the trainer so they are
    # added to the vocabulary as-is and never split into wordpieces.
    special_tokens = [pad_token, unk_token, sos_token, eos_token]
    if sentinel_tokens is not None:
        special_tokens.extend(sentinel_tokens)
    if coord_tokens is not None:
        special_tokens.extend(coord_tokens)
    if object_class_tokens is not None:
        special_tokens.extend(object_class_tokens)
    if additional_special_tokens is not None:
        special_tokens.extend(additional_special_tokens)

    trainer = trainers.WordPieceTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        show_progress=show_progress,
        continuing_subword_prefix=wordpieces_prefix,
        special_tokens=special_tokens,
    )
    if isinstance(files, str):
        files = [files]
    tokenizer.train(files, trainer=trainer)
    return tokenizer
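
# Example usage (a minimal sketch, not executed on import; "corpus.txt" and the
# vocab size are placeholder values, not prescribed by this module):
#
#   tokenizer = train_unified_wordpiece_tokenizer(
#       files="corpus.txt",
#       vocab_size=30_000,
#       sentinel_tokens=generate_sentinel_tokens(),
#       coord_tokens=generate_coord_tokens(),
#       object_class_tokens=generate_object_class_tokens("coco"),
#   )
#   tokenizer.save("unified_tokenizer.json")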


def get_sentinel_to_id_mapping(tokenizer, match_str="[S_"):
    sentinel_tokens = {k: v for k, v in tokenizer.get_vocab().items() if k.startswith(match_str)}
    # Extract the sentinel index from tokens of the form "[S_0]", "[S_1]", etc.
    sentinel_to_id = {int(k.split("_")[1][:-1]): v for k, v in sorted(sentinel_tokens.items(), key=lambda x: x[1])}
    return sentinel_to_id
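
# Sketch of the resulting mapping (illustrative; ids depend on the trained vocabulary):
#   sentinel_to_id = get_sentinel_to_id_mapping(tokenizer)
#   sentinel_to_id[0] == tokenizer.token_to_id("[S_0]")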


def split_by_sentinel(seq_ids, sentinel_ids):
    """Group token ids by the most recent sentinel id preceding them."""
    splits = defaultdict(list)
    cur_sentinel = None
    for token in seq_ids:
        if token in sentinel_ids:
            cur_sentinel = token
        else:
            splits[cur_sentinel].append(token)
    return splits
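
# E.g. split_by_sentinel([1, 11, 12, 2, 14], sentinel_ids={1, 2}) returns
# {1: [11, 12], 2: [14]} (illustrative token ids only).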


def merge_span_masking(input_seq, decoder_seq, sentinel_ids):
    """Replace each sentinel in input_seq with the tokens the decoder produced for it."""
    decoder_splits = split_by_sentinel(decoder_seq, sentinel_ids)
    out_seq = []
    for token in input_seq:
        if token in sentinel_ids:
            out_seq.extend(decoder_splits[token])
        else:
            out_seq.append(token)
    return out_seq
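
# Worked example (illustrative token ids; assuming 1 and 2 are sentinel ids):
#   input_seq   = [10, 1, 13, 2]          # "the [S_0] sat [S_1]"
#   decoder_seq = [1, 11, 12, 2, 14, 15]  # "[S_0] black cat [S_1] on the"
#   merge_span_masking(input_seq, decoder_seq, sentinel_ids={1, 2})
#   -> [10, 11, 12, 13, 14, 15]           # the masked spans are spliced back in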