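"""Streamlit playground that compares how the Llama2 and Gujju Llama tokenizers
split the same input text, rendering each token as a color-coded HTML span."""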
import os
from html import escape

import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer
st.set_page_config(layout="wide")

# Authenticate with the Hugging Face Hub; the token (read from the "hf_token"
# environment variable / Space secret) is needed for gated repos like meta-llama.
token = os.environ.get("hf_token")
if token:
    login(token=token)
class TokenizationVisualizer:
    """Holds named tokenizers and renders their token splits as color-coded HTML."""

    def __init__(self):
        self.tokenizers = {}

    def add_tokenizer(self, name, model_name):
        """Download a tokenizer from the Hugging Face Hub and register it under `name`."""
        self.tokenizers[name] = AutoTokenizer.from_pretrained(model_name)

    def visualize_tokens(self, text, tokenizer):
        """Tokenize `text` and return (HTML markup of highlighted tokens, token IDs)."""
        tokens = tokenizer.tokenize(text)
        # Convert each token back to its surface form so subword markers
        # (e.g. SentencePiece's "▁") render as readable text.
        str_tokens = [tokenizer.convert_tokens_to_string([token]) for token in tokens]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        # Cycle through a small palette so adjacent tokens stay visually distinct.
        colors = ['#ffdab9', '#e6ee9c', '#9cddc8', '#bcaaa4', '#c5b0d5']
        markup = ""
        for i, str_token in enumerate(str_tokens):
            color = colors[i % len(colors)]
            # Escape the token text so characters like "<" don't break the HTML.
            escaped = escape(str_token)
            markup += f'<mark title="{escaped}" style="background-color: {color};">{escaped}</mark>'
        return markup, token_ids
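# Example of using the visualizer outside Streamlit (hypothetical snippet; any
# public tokenizer repo works here):
#
#   viz = TokenizationVisualizer()
#   viz.add_tokenizer("GPT-2", "gpt2")
#   markup, ids = viz.visualize_tokens("Hello world", viz.tokenizers["GPT-2"])
#   print(ids)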
def playground_tab(visualizer):
    st.title("Tokenization Visualizer for Language Models")
    st.markdown("""
    Use this playground to compare the tokens produced by the Llama2 and Gujju Llama tokenizers.
    """)
    text_input = st.text_area("Enter text below to visualize tokens:", height=300)
    if st.button("Tokenize"):
        st.divider()
        if text_input.strip():
            # Tokenize the same input with both tokenizers so the splits can be
            # compared side by side.
            llama_tokenization_results, llama_token_ids = visualizer.visualize_tokens(text_input, visualizer.tokenizers["Llama2"])
            gujju_tokenization_results, gujju_token_ids = visualizer.visualize_tokens(text_input, visualizer.tokenizers["Gujju Llama"])
            col1, col2 = st.columns(2)
            col1.title('Llama2 Tokenizer')
            col1.container(height=200, border=True).markdown(llama_tokenization_results, unsafe_allow_html=True)
            with col1.expander(f"Token IDs (Token Count = {len(llama_token_ids)})"):
                st.markdown(llama_token_ids)
            col2.title('Gujju Llama Tokenizer')
            col2.container(height=200, border=True).markdown(gujju_tokenization_results, unsafe_allow_html=True)
            with col2.expander(f"Token IDs (Token Count = {len(gujju_token_ids)})"):
                st.markdown(gujju_token_ids)
        else:
            st.error("Please enter some text.")
def main():
    # Display names mapped to the Hub repos whose tokenizers are compared.
    huggingface_tokenizers = {
        "Gujju Llama": "sampoorna42/Gujju-Llama-Instruct-v0.1",
        "Llama2": "meta-llama/Llama-2-7b-hf",
    }
    visualizer = TokenizationVisualizer()
    for name, src in huggingface_tokenizers.items():
        visualizer.add_tokenizer(name, src)
    playground_tab(visualizer)


if __name__ == "__main__":
    main()
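# To run locally (assuming this file is saved as app.py and, for the gated
# Llama2 tokenizer, that hf_token is set in the environment):
#
#   streamlit run app.py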