# athnlp2025_tokenization / character_util.py
import json
import os
from pathlib import Path
from typing import Any, Literal
import numpy as np
import pandas as pd
from utils.lang_util import detect_language_by_unicode, language_ranges
from utils.log_util import logger
from utils.text_util import contains_digit, get_space_count
from vocab import tokenizer_factory
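# Directory containing this file; cache output is written relative to it.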
CURRENT_DIR = Path(__file__).resolve().parent
cache = {}
default_columns = ["digit", "zh"]
def text_to_unicode(text: str) -> str:
"""Convert text to unicode representation."""
return "".join(rf"\u{ord(character):04X}" for character in text)
def calculate_dist(token_lens: list[int]) -> str:
"""Calculate the distribution of token lengths."""
if not token_lens:
return "-"
return f"{min(token_lens)},{round(np.median(token_lens))},{max(token_lens)}"
def iter_vocab(
    tokenizer_name: str,
    from_cache: bool = True,
    cache_dir: str = "stats",
) -> dict:
    """Collect character-level statistics over a tokenizer's vocabulary.

    :param tokenizer_name: tokenizer name registered with the tokenizer factory
    :param from_cache: return previously computed statistics when available
    :param cache_dir: directory (relative to this file) where results are cached
    :return: dict of summary statistics for the tokenizer
    """
tokenizer_config = tokenizer_factory.get_tokenizer_config(tokenizer_name)
cache_dir = os.path.join(CURRENT_DIR, cache_dir)
os.makedirs(cache_dir, exist_ok=True)
# load from cache
cache_path = os.path.join(cache_dir, "character_stats.json")
if not cache and os.path.exists(cache_path):
with open(cache_path, encoding="utf-8") as f_tmp:
cache.update(json.load(f_tmp))
if from_cache and tokenizer_name in cache:
# logger.info(f"load {tokenizer_config.name_or_path} from cache")
return cache[tokenizer_name]
tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
tokens_by_lang = {lang[1]: [] for lang in language_ranges}
digit_tokens = []
space_tokens = []
byte_tokens = []
buffer = []
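    # Walk every id in the vocabulary: tag byte tokens, digit tokens and whitespace-containing
    # tokens, and bucket decoded strings by detected Unicode script.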
for token_id in range(tokenizer.vocab_size):
# for token_id in tokenizer.get_vocab():
# for token_id in range(len(tokenizer)):
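        # decode() yields the human-readable surface form; convert_ids_to_tokens()
        # returns the raw vocabulary entry (which may carry byte/space markers).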
decode_str = tokenizer.decode([token_id], skip_special_tokens=False)
token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
tags = []
        if token is None:  # some vocabularies have empty ids (non-contiguous vocab); skip them
continue
if isinstance(token, bytes):
token = token.decode("utf-8", errors="ignore")
if hasattr(tokenizer, "sp_model") and tokenizer.sp_model.is_byte(token_id):
tags.append("is_byte")
byte_tokens.append(token)
language_tags = detect_language_by_unicode(decode_str)
for language in language_tags:
tokens_by_lang[language[1]].append(decode_str)
if contains_digit(decode_str):
tags.append("digit")
digit_tokens.append(decode_str)
space_count = get_space_count(decode_str)
if space_count > 0:
space_tokens.append(decode_str)
buffer.append(
json.dumps(
{
"id": token_id,
"token": token,
"token_decode": decode_str,
"token_dumps": json.dumps(token),
"token_unicode": text_to_unicode(token),
"token_len": len(decode_str),
},
ensure_ascii=False,
)
+ "\n"
)
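    # Aggregate summary statistics: counts plus min/median/max token length per category.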
result = {
"tokenizer": tokenizer_factory.get_name_with_hyperlink(tokenizer_name),
"organization": tokenizer_config.org,
"vocab_size": len(tokenizer),
"num(digit)": len(digit_tokens),
"len(digit)": calculate_dist([len(token) for token in digit_tokens]),
"num(space)": len(space_tokens),
"len(space)": calculate_dist([len(token) for token in space_tokens]),
}
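    # Add per-language counts and length distributions (one column pair per Unicode range).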
for lang, tokens in tokens_by_lang.items():
result[f"num({lang})"] = len(tokens)
result["len(" + lang + ")"] = calculate_dist([len(token) for token in tokens])
    out_path = os.path.join(
        cache_dir, f"iter_vocab/{tokenizer_name.replace('/', '_')}.vocab.jsonl"
    )
    # The iter_vocab/ subdirectory is not created by the earlier makedirs call, so create it here.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f_out:
for line in buffer:
f_out.write(line)
len_before = len(cache)
cache[tokenizer_name] = result
len_after = len(cache)
logger.info(f"saving {tokenizer_name} to memory and file cache: {len_before}->{len_after}")
with open(cache_path, "w", encoding="utf-8") as f_out:
f_out.write(json.dumps(cache, ensure_ascii=False, indent=2))
return result
def to_dataframe(stats: dict[str, Any], columns: list[str]) -> pd.DataFrame:
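    """Flatten per-tokenizer stats into a DataFrame, keeping metadata columns and
    only the num(...)/len(...) columns that match the requested `columns`."""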
table = []
for stat in stats.values():
filtered_stat = {}
for k, v in stat.items():
if not k.startswith("num") and not k.startswith("len"):
filtered_stat[k] = v
if any(column in k for column in columns):
k = k.replace("ja-kana", "kana")
filtered_stat[k] = v
table.append(filtered_stat)
return pd.DataFrame(table)
def get_character_table(
tokenizer_filter: str | None = None,
    columns: list[str] | None = None,
return_type: Literal["dict", "dataframe"] | None = "dataframe",
) -> pd.DataFrame | dict:
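    """Build character statistics for every tokenizer whose name matches `tokenizer_filter`,
    returned either as a dict of raw stats or as a pandas DataFrame."""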
logger.info(f"columns: {columns}, tokenizer_filter: {tokenizer_filter}")
stats = {}
if columns is None:
columns = default_columns
if tokenizer_filter is not None:
tokenizer_names = [
tokenizer_config.name_or_path
for tokenizer_config in tokenizer_factory.all_tokenizer_configs
if tokenizer_filter.lower() in tokenizer_config.name_or_path.lower()
]
else:
tokenizer_names = tokenizer_factory.all_tokenizer_names
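    # Compute (or load cached) statistics for each selected tokenizer.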
for tokenizer_name in tokenizer_names:
stat = iter_vocab(tokenizer_name)
stats[tokenizer_name] = stat
if return_type == "dataframe":
stats = to_dataframe(stats, columns)
return stats
if __name__ == "__main__":
# aa = get_character_table(tokenizer_filter="baichuan")
df = get_character_table()
logger.info(f"\n{df.to_markdown(index=False)}")