import os
import pandas as pd
import streamlit as st
from transformers import AutoTokenizer
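
# When this app runs on a Hugging Face Space, SPACE_ID is set; point the
# Hugging Face caches at /data (the Space's persistent storage, when enabled)
# so downloaded models are reused across restarts.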
if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False

DEFAULT_TOKENIZER_NAME = os.environ.get(
    "DEFAULT_TOKENIZER_NAME", "tohoku-nlp/bert-base-japanese-v3"
)
DEFAULT_TEXT = """
hello world!
こんにちは、世界!
你好,世界
""".strip()
DEFAULT_COLOR = "gray"
COLORS_CYCLE = [
    "yellow",
    "cyan",
]
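

# Returns an infinite generator that yields the colors in COLORS_CYCLE over and over.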
def color_cycle_generator():
    def _color_cycle_generator():
        while True:
            for color in COLORS_CYCLE:
                yield color

    return _color_cycle_generator()
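

# Cache the loaded tokenizer across Streamlit reruns so it is only
# downloaded and instantiated once per model name.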
@st.cache_resource
def get_tokenizer(tokenizer_name: str = DEFAULT_TOKENIZER_NAME):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer


def main():
    st.set_page_config(
        page_title="TokenViz: AutoTokenizer Visualization Tool",
        layout="centered",
        initial_sidebar_state="auto",
    )
    st.title("TokenViz: AutoTokenizer Visualization Tool")

    st.text_input(
        "AutoTokenizer model name", key="tokenizer_name", value=DEFAULT_TOKENIZER_NAME
    )
    if st.session_state.tokenizer_name:
        tokenizer = get_tokenizer(st.session_state.tokenizer_name)

    st.text_input("subword prefix", key="subword_prefix", value="##")
    st.text_area("text", key="text", height=200, value=DEFAULT_TEXT)

    # Submit
    if st.button("tokenize"):
        text = st.session_state.text.strip()
        subword_prefix = st.session_state.subword_prefix.strip()
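
        # Encode with special tokens (e.g. [CLS]/[SEP] for BERT-style tokenizers)
        # and map the ids back to their surface strings for display.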
        token_ids = tokenizer.encode(text, add_special_tokens=True)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        total_tokens = len(tokens)

        token_table_df = pd.DataFrame(
            {
                "token_id": token_ids,
                "token": tokens,
            }
        )
st.subheader("visualized tokens")
st.markdown(f"total tokens: **{total_tokens}**")
tab_main, tab_token_table = st.tabs(["tokens", "table"])
color_gen = color_cycle_generator()
        with tab_main:
            current_subword_color = next(color_gen)
            token_html = ""
            for idx, (token_id, token) in enumerate(zip(token_ids, tokens)):
                if len(subword_prefix) == 0:
                    token_border = f"1px solid {DEFAULT_COLOR}"
                else:
                    current_token_is_subword = token.startswith(subword_prefix)
                    next_token_is_subword = idx + 1 < total_tokens and tokens[
                        idx + 1
                    ].startswith(subword_prefix)
                    if next_token_is_subword and not current_token_is_subword:
                        current_subword_color = next(color_gen)
                    if current_token_is_subword or next_token_is_subword:
                        token_border = f"1px solid {current_subword_color}"
                    else:
                        token_border = f"1px solid {DEFAULT_COLOR}"

                # Escape angle brackets so tokens like "<s>" render as text, not HTML tags.
                html_escaped_token = token.replace("<", "&lt;").replace(">", "&gt;")
                token_html += f'<span title="{str(token_id)}" style="border: {token_border}; border-radius: 3px; padding: 2px; margin: 2px;">{html_escaped_token}</span>'
            st.html(
                f"<p style='line-height:2em;'>{token_html}</p>",
            )
st.subheader("token_ids")
token_ids_str = ",".join(map(str, token_ids))
st.code(token_ids_str)
with tab_token_table:
st.table(token_table_df)


if __name__ == "__main__":
    main()
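
# Note: a minimal way to try this app outside the Space (assuming streamlit,
# pandas, and transformers are installed; the default Japanese tokenizer may
# additionally require fugashi and unidic-lite) would be:
#
#   streamlit run app.py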