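"""TokenViz: a Streamlit app that visualizes how a Hugging Face AutoTokenizer
splits text into tokens.

Run locally (assuming this file is saved as app.py): streamlit run app.py
"""
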
import html
import os

import pandas as pd
import streamlit as st
from transformers import AutoTokenizer

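# On Hugging Face Spaces (SPACE_ID is set in the environment), point the
# Hugging Face caches at the persistent /data volume so downloaded models
# survive restarts.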
if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False

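# The default model can be overridden via the DEFAULT_TOKENIZER_NAME
# environment variable.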
DEFAULT_TOKENIZER_NAME = os.environ.get(
    "DEFAULT_TOKENIZER_NAME", "tohoku-nlp/bert-base-japanese-v3"
)
DEFAULT_TEXT = """
hello world!
こんにちは、世界!
你好,世界
""".strip()

DEFAULT_COLOR = "gray"
COLORS_CYCLE = [
    "yellow",
    "cyan",
]


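# Endless generator over COLORS_CYCLE; each new subword group takes the next color.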
def color_cycle_generator():
    while True:
        yield from COLORS_CYCLE


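# Cache the loaded tokenizer across Streamlit reruns so each model is
# downloaded and instantiated only once.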
@st.cache_resource
def get_tokenizer(tokenizer_name: str = DEFAULT_TOKENIZER_NAME):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer


def main():
    st.set_page_config(
        page_title="TokenViz: AutoTokenizer Visualization Tool",
        layout="centered",
        initial_sidebar_state="auto",
    )

    st.title("TokenViz: AutoTokenizer Visualization Tool")
    st.text_input(
        "AutoTokenizer model name", key="tokenizer_name", value=DEFAULT_TOKENIZER_NAME
    )
    if st.session_state.tokenizer_name:
        tokenizer = get_tokenizer(st.session_state.tokenizer_name)
    else:
        # Without a model name there is nothing to tokenize; stop this script
        # run instead of hitting a NameError in the button handler below.
        st.warning("Please enter an AutoTokenizer model name.")
        st.stop()
    st.text_input("subword prefix", key="subword_prefix", value="##")
    st.text_area("text", key="text", height=200, value=DEFAULT_TEXT)
    # Submit
    if st.button("tokenize"):
        text = st.session_state.text.strip()
        subword_prefix = st.session_state.subword_prefix.strip()
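        # add_special_tokens=True includes markers such as [CLS]/[SEP]
        # for BERT-style tokenizers.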
        token_ids = tokenizer.encode(text, add_special_tokens=True)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        total_tokens = len(tokens)
        token_table_df = pd.DataFrame(
            {
                "token_id": token_ids,
                "token": tokens,
            }
        )

        st.subheader("visualized tokens")
        st.markdown(f"total tokens: **{total_tokens}**")
        tab_main, tab_token_table = st.tabs(["tokens", "table"])

        color_gen = color_cycle_generator()
        with tab_main:
            current_subword_color = next(color_gen)
            token_html = ""
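            # Render each token as a bordered span. Consecutive subwords share
            # a border color so tokens belonging to one word are visually grouped.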
            for idx, (token_id, token) in enumerate(zip(token_ids, tokens)):
                if len(subword_prefix) == 0:
                    token_border = f"1px solid {DEFAULT_COLOR}"
                else:
                    current_token_is_subword = token.startswith(subword_prefix)
                    next_token_is_subword = idx + 1 < total_tokens and tokens[
                        idx + 1
                    ].startswith(subword_prefix)

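                    # A non-subword token followed by a subword starts a new
                    # word group: rotate to the next highlight color.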
                    if next_token_is_subword and not current_token_is_subword:
                        current_subword_color = next(color_gen)

                    if current_token_is_subword or next_token_is_subword:
                        token_border = f"1px solid {current_subword_color}"
                    else:
                        token_border = f"1px solid {DEFAULT_COLOR}"

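                # Escape the token before embedding it in HTML; hovering a
                # span shows its token_id via the title attribute.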
                html_escaped_token = html.escape(token)
                token_html += f'<span title="{str(token_id)}" style="border: {token_border}; border-radius: 3px; padding: 2px; margin: 2px;">{html_escaped_token}</span>'
            st.html(
                f"<p style='line-height:2em;'>{token_html}</p>",
            )

            st.subheader("token_ids")

            token_ids_str = ",".join(map(str, token_ids))
            st.code(token_ids_str)

        with tab_token_table:
            st.table(token_table_df)


if __name__ == "__main__":
    main()