"""TokenViz: a Streamlit app that visualizes how an AutoTokenizer splits text."""

import html
import itertools
import os

# The Hugging Face cache location must be configured BEFORE transformers is
# imported: transformers / huggingface_hub resolve their cache paths from
# HF_HOME at import time, so setting it after the import has no effect.
if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False

import pandas as pd  # noqa: E402
import streamlit as st  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

DEFAULT_TOKENIZER_NAME = os.environ.get(
    "DEFAULT_TOKENIZER_NAME", "tohoku-nlp/bert-base-japanese-v3"
)
DEFAULT_TEXT = """
hello world!
こんにちは、世界!
你好,世界
""".strip()

# Border color for tokens that are not part of a multi-piece word.
DEFAULT_COLOR = "gray"
# Border colors alternated between consecutive multi-piece words.
COLORS_CYCLE = [
    "yellow",
    "cyan",
]


def color_cycle_generator():
    """Return an endless iterator over ``COLORS_CYCLE``.

    If ``COLORS_CYCLE`` is empty the iterator is simply exhausted instead of
    spinning forever; callers use ``next(gen, DEFAULT_COLOR)``.
    """
    return itertools.cycle(COLORS_CYCLE)


@st.cache_resource
def get_tokenizer(tokenizer_name: str = DEFAULT_TOKENIZER_NAME):
    """Load the tokenizer for *tokenizer_name*, cached by Streamlit per name.

    Raises whatever ``AutoTokenizer.from_pretrained`` raises (e.g. ``OSError``
    for an unknown model id); Streamlit does not cache failed calls.
    """
    return AutoTokenizer.from_pretrained(tokenizer_name)


def _render_tokens_html(token_ids, tokens, subword_prefix):
    """Build HTML that shows every token in its own bordered box.

    A multi-piece word (a piece followed by pieces starting with
    *subword_prefix*) gets one border color from ``COLORS_CYCLE``; all other
    tokens use ``DEFAULT_COLOR``. The token id is exposed as a hover tooltip.

    Args:
        token_ids: Token ids, parallel to *tokens*.
        tokens: Token strings as returned by ``convert_ids_to_tokens``.
        subword_prefix: Continuation marker such as ``"##"``; empty disables
            subword grouping.

    Returns:
        The concatenated ``<span>`` elements as one string.
    """
    color_gen = color_cycle_generator()
    current_subword_color = next(color_gen, DEFAULT_COLOR)

    # "".startswith("") is True for every token, so an empty prefix must
    # disable grouping explicitly rather than marking everything as subword.
    if subword_prefix:
        is_subword = [token.startswith(subword_prefix) for token in tokens]
    else:
        is_subword = [False] * len(tokens)

    spans = []
    for idx, (token_id, token) in enumerate(zip(token_ids, tokens)):
        current_is_subword = is_subword[idx]
        next_is_subword = idx + 1 < len(tokens) and is_subword[idx + 1]
        if next_is_subword and not current_is_subword:
            # First piece of a new multi-piece word: switch to the next color.
            current_subword_color = next(color_gen, DEFAULT_COLOR)
        if current_is_subword or next_is_subword:
            border_color = current_subword_color
        else:
            border_color = DEFAULT_COLOR
        spans.append(
            f'<span title="{token_id}" style="border: 1px solid {border_color}; '
            "border-radius: 3px; padding: 0 2px; margin: 1px; "
            'display: inline-block;">'
            # Tokens derive from user text: escape &, <, > and quotes.
            f"{html.escape(token)}</span>"
        )
    return "".join(spans)


def main():
    """Render the TokenViz page: inputs, tokenization, and result tabs."""
    st.set_page_config(
        page_title="TokenViz: AutoTokenizer Visualization Tool",
        layout="centered",
        initial_sidebar_state="auto",
    )
    st.title("TokenViz: AutoTokenizer Visualization Tool")

    st.text_input(
        "AutoTokenizer model name", key="tokenizer_name", value=DEFAULT_TOKENIZER_NAME
    )
    tokenizer_name = st.session_state.tokenizer_name.strip()
    tokenizer = None
    if not tokenizer_name:
        st.warning("Please enter an AutoTokenizer model name.")
    else:
        try:
            tokenizer = get_tokenizer(tokenizer_name)
        except (OSError, ValueError, ImportError) as err:
            # Unknown model id, unsupported config, or a missing tokenizer
            # dependency: report it instead of crashing the page.
            st.error(f"Failed to load tokenizer {tokenizer_name!r}: {err}")

    st.text_input("subword prefix", key="subword_prefix", value="##")
    st.text_area("text", key="text", height=200, value=DEFAULT_TEXT)

    # Submit (disabled while no tokenizer is available, avoiding a NameError).
    if st.button("tokenize", disabled=tokenizer is None):
        text = st.session_state.text.strip()
        subword_prefix = st.session_state.subword_prefix.strip()

        token_ids = tokenizer.encode(text, add_special_tokens=True)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        token_table_df = pd.DataFrame(
            {
                "token_id": token_ids,
                "token": tokens,
            }
        )

        st.subheader("visualized tokens")
        st.markdown(f"total tokens: **{len(tokens)}**")
        tab_main, tab_token_table = st.tabs(["tokens", "table"])

        with tab_main:
            token_html = _render_tokens_html(token_ids, tokens, subword_prefix)
            st.html(f'<div style="line-height: 1.8;">{token_html}</div>')
            st.subheader("token_ids")
            st.code(",".join(map(str, token_ids)))

        with tab_token_table:
            st.table(token_table_df)


if __name__ == "__main__":
    main()