"""Character statistics of tokenizer vocabularies: counts and length
distributions of digit, space, byte, and per-language tokens."""
import json
import os
from pathlib import Path
from typing import Any, Literal

import numpy as np
import pandas as pd

from utils.lang_util import detect_language_by_unicode, language_ranges
from utils.log_util import logger
from utils.text_util import contains_digit, get_space_count
from vocab import tokenizer_factory
# Directory containing this file; cache paths are resolved relative to it.
# Fix: `Path.parent` is a property (not callable) and `Path.resolve(__file__)`
# passed a str where a Path is required — build the Path first, then chain.
CURRENT_DIR = Path(__file__).resolve().parent

# In-memory cache of per-tokenizer stats, mirrored to stats/character_stats.json.
cache = {}

# Stat groups shown by default in get_character_table().
default_columns = ["digit", "zh"]
def text_to_unicode(text: str) -> str:
    r"""Render *text* as a run of ``\uXXXX`` escapes (uppercase hex, 4 digits)."""
    escapes = [rf"\u{ord(ch):04X}" for ch in text]
    return "".join(escapes)
def calculate_dist(token_lens: list[int]) -> str:
    """Summarize token lengths as ``"min,median,max"``; ``"-"`` when empty."""
    if token_lens:
        low = min(token_lens)
        mid = round(np.median(token_lens))
        high = max(token_lens)
        return f"{low},{mid},{high}"
    return "-"
def iter_vocab(
    tokenizer_name: str,
    from_cache: bool = True,
    cache_dir: str = "stats",
) -> dict:
    """Collect character-level statistics over one tokenizer's vocabulary.

    Decodes every token id and tallies tokens containing digits, spaces,
    raw bytes, and characters from each known language range. Per-token
    details are dumped to ``<cache_dir>/iter_vocab/<name>.vocab.jsonl``;
    the summary is memoized in ``cache`` and persisted to
    ``<cache_dir>/character_stats.json``.

    :param tokenizer_name: name understood by ``tokenizer_factory``
    :param from_cache: return the cached summary when available
    :param cache_dir: directory for cache files, relative to this file
    :return: summary statistics dict for this tokenizer
    """
    tokenizer_config = tokenizer_factory.get_tokenizer_config(tokenizer_name)
    cache_dir = os.path.join(CURRENT_DIR, cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    # Hydrate the in-memory cache from disk once (first call only).
    cache_path = os.path.join(cache_dir, "character_stats.json")
    if not cache and os.path.exists(cache_path):
        with open(cache_path, encoding="utf-8") as f_tmp:
            cache.update(json.load(f_tmp))
    if from_cache and tokenizer_name in cache:
        return cache[tokenizer_name]

    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)
    tokens_by_lang = {lang[1]: [] for lang in language_ranges}
    digit_tokens = []
    space_tokens = []
    byte_tokens = []
    buffer = []
    for token_id in range(tokenizer.vocab_size):
        decode_str = tokenizer.decode([token_id], skip_special_tokens=False)
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        tags = []
        if token is None:  # some vocabularies have holes (non-contiguous ids)
            continue
        if isinstance(token, bytes):
            token = token.decode("utf-8", errors="ignore")
        if hasattr(tokenizer, "sp_model") and tokenizer.sp_model.is_byte(token_id):
            tags.append("is_byte")
            byte_tokens.append(token)
        language_tags = detect_language_by_unicode(decode_str)
        for language in language_tags:
            tokens_by_lang[language[1]].append(decode_str)
        if contains_digit(decode_str):
            tags.append("digit")
            digit_tokens.append(decode_str)
        space_count = get_space_count(decode_str)
        if space_count > 0:
            space_tokens.append(decode_str)
        buffer.append(
            json.dumps(
                {
                    "id": token_id,
                    "token": token,
                    "token_decode": decode_str,
                    "token_dumps": json.dumps(token),
                    "token_unicode": text_to_unicode(token),
                    "token_len": len(decode_str),
                },
                ensure_ascii=False,
            )
            + "\n"
        )

    result = {
        "tokenizer": tokenizer_factory.get_name_with_hyperlink(tokenizer_name),
        "organization": tokenizer_config.org,
        "vocab_size": len(tokenizer),
        "num(digit)": len(digit_tokens),
        "len(digit)": calculate_dist([len(token) for token in digit_tokens]),
        "num(space)": len(space_tokens),
        "len(space)": calculate_dist([len(token) for token in space_tokens]),
    }
    for lang, tokens in tokens_by_lang.items():
        result[f"num({lang})"] = len(tokens)
        result["len(" + lang + ")"] = calculate_dist([len(token) for token in tokens])

    out_path = os.path.join(
        cache_dir, f"iter_vocab/{tokenizer_name.replace('/', '_')}.vocab.jsonl"
    )
    # Fix: the iter_vocab/ subdirectory was never created, so the open()
    # below raised FileNotFoundError on a fresh cache_dir.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f_out:
        f_out.writelines(buffer)

    len_before = len(cache)
    cache[tokenizer_name] = result
    len_after = len(cache)
    logger.info(f"saving {tokenizer_name} to memory and file cache: {len_before}->{len_after}")
    with open(cache_path, "w", encoding="utf-8") as f_out:
        f_out.write(json.dumps(cache, ensure_ascii=False, indent=2))
    return result
def to_dataframe(stats: dict[str, Any], columns: list[str]) -> pd.DataFrame:
    """Flatten per-tokenizer stats into a DataFrame, one row per tokenizer.

    Metadata keys (anything not starting with ``num``/``len``) are always
    kept; ``num(...)``/``len(...)`` stats are kept only when their name
    contains one of *columns*, with ``ja-kana`` renamed to ``kana``.

    :param stats: mapping of tokenizer name -> stats dict from ``iter_vocab``
    :param columns: substrings selecting which num/len stats to keep
    :return: DataFrame with one row per entry in *stats*
    """
    table = []
    for stat in stats.values():
        filtered_stat = {}
        for k, v in stat.items():
            if not k.startswith(("num", "len")):
                # metadata column: always kept
                filtered_stat[k] = v
            # elif (not a second if) so a metadata key matching a column
            # substring is not inserted a second time under a renamed key
            elif any(column in k for column in columns):
                k = k.replace("ja-kana", "kana")
                filtered_stat[k] = v
        table.append(filtered_stat)
    return pd.DataFrame(table)
def get_character_table(
    tokenizer_filter: str | None = None,
    columns: list | None = None,
    return_type: Literal["dict", "dataframe"] | None = "dataframe",
) -> pd.DataFrame | dict:
    """Build the character-statistics table across tokenizers.

    :param tokenizer_filter: case-insensitive substring restricting which
        tokenizers are included; None means all registered tokenizers
    :param columns: substrings selecting which num/len stats to keep
        (defaults to ``default_columns``)
    :param return_type: "dataframe" for a pandas table, anything else for
        the raw per-tokenizer stats dict
    :return: DataFrame or dict keyed by tokenizer name
    """
    logger.info(f"columns: {columns}, tokenizer_filter: {tokenizer_filter}")
    if columns is None:
        columns = default_columns

    if tokenizer_filter is None:
        tokenizer_names = tokenizer_factory.all_tokenizer_names
    else:
        needle = tokenizer_filter.lower()
        tokenizer_names = [
            config.name_or_path
            for config in tokenizer_factory.all_tokenizer_configs
            if needle in config.name_or_path.lower()
        ]

    stats = {name: iter_vocab(name) for name in tokenizer_names}
    if return_type == "dataframe":
        return to_dataframe(stats, columns)
    return stats
if __name__ == "__main__":
    # Render the full character table for all registered tokenizers.
    table = get_character_table()
    logger.info(f"\n{table.to_markdown(index=False)}")