import html
import logging
from pathlib import Path

import gradio as gr
from gradio.themes.utils import colors
from transformers import CLIPTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
gr_logger = logging.getLogger("gradio")
gr_logger.setLevel(logging.INFO)


class ClipUtil:
    """Gradio app that shows how the CLIP tokenizer splits input into tokens.

    Input is either a prompt, tokenized with the
    ``openai/clip-vit-large-patch14`` tokenizer, or a comma-separated list of
    token ids. Both are rendered as per-word markup plus a token-id summary.
    """

    def __init__(self):
        logger.info("Loading ClipUtil")
        self.theme = gr.themes.Base(
            primary_hue=colors.violet,
            secondary_hue=colors.indigo,
            neutral_hue=colors.slate,
            font=[gr.themes.GoogleFont("Fira Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
            font_mono=[gr.themes.GoogleFont("Fira Code"), "ui-monospace", "Consolas", "monospace"],
        ).set(
            slider_color_dark="*primary_500",
        )

        # The stylesheet sits next to this file with a .css suffix. It is
        # best-effort: on any failure the app starts without custom CSS.
        try:
            self.css = Path(__file__).with_suffix(".css").read_text()
        except Exception:
            logger.exception("Failed to load CSS file")
            self.css = ""

        self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # Reverse vocabulary: token id -> BPE token string.
        self.vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}
        self.blocks = gr.Blocks(
            title="ClipTokenizerUtil", analytics_enabled=False, theme=self.theme, css=self.css
        )

    def tokenize(self, text: str, input_ids: bool = False):
        """Tokenize *text* and render the result as two HTML fragments.

        Args:
            text: The prompt to tokenize or, when *input_ids* is true, a
                comma-separated list of integer token ids. Blank entries
                (empty input, trailing or doubled commas) are ignored.
            input_ids: Interpret *text* as token ids instead of running the
                tokenizer.

        Returns:
            ``(code, ids_html)``: one rendered fragment per decoded word, and
            a summary with the token count and the comma-joined ids.

        Raises:
            ValueError: If, in id mode, a non-blank entry is not an integer.
        """
        if input_ids:
            # int() already tolerates surrounding whitespace; skipping blank
            # entries avoids int("") raising on a trailing comma.
            tokens = [int(part) for part in text.split(",") if part.strip()]
        else:
            # The tokenizer returns a batch containing one sequence; take row 0
            # instead of squeeze() so a one-token result is still a list.
            tokens = self.tokenizer(text, return_tensors="np").input_ids[0].tolist()

        code = ""
        ids = []
        current_ids = []
        class_index = 0
        byte_decoder = self.tokenizer.byte_decoder

        def wordscode(word_ids, word):
            """Render the markup for one word and advance the word counter."""
            nonlocal class_index
            # NOTE(review): word_title and class_index are computed but not
            # referenced by the template below, which looks truncated — confirm.
            word_title = html.escape(", ".join(str(x) for x in word_ids))  # noqa: F841
            res = f""" {html.escape(word)} """
            class_index += 1
            return res

        def dump(last=False):
            """Flush the pending ids in ``current_ids`` as one rendered word.

            The ids' vocab strings are mapped back to raw bytes via the
            tokenizer's byte decoder. If those bytes are not valid UTF-8 yet,
            the ids stay pending so later tokens can complete the character.
            Once more than 4 ids are pending, the first is emitted as "❌" and
            the rest are replayed one by one. On the final flush (*last*), any
            still-undecodable ids are shown as one "❌" per id.
            """
            nonlocal code, ids, current_ids

            # Nothing pending, e.g. the final flush right after a successful
            # one: emit nothing rather than an empty word.
            if not current_ids:
                return

            words = [self.vocab.get(x, "") for x in current_ids]
            try:
                word = bytearray([byte_decoder[ch] for ch in "".join(words)]).decode("utf-8")
            except UnicodeDecodeError:
                if last:
                    word = "❌" * len(current_ids)
                elif len(current_ids) > 4:
                    head, *rest = current_ids
                    ids += [head]
                    code += wordscode([head], "❌")

                    # Replay the remaining ids through dump() one at a time.
                    current_ids = []
                    for token_id in rest:
                        current_ids.append(token_id)
                        dump()
                    return
                else:
                    # Keep the ids pending until more tokens arrive.
                    return

            code += wordscode(current_ids, word)
            ids += current_ids
            current_ids = []

        for token in tokens:
            current_ids.append(int(token))
            dump()
        dump(last=True)

        ids_html = f"""

Token count: {len(ids)}
{", ".join([str(x) for x in ids])}

"""
        return code, ids_html

    def tokenize_ids(self, text: str):
        """Render *text* as comma-separated token ids; see :meth:`tokenize`."""
        return self.tokenize(text, input_ids=True)

    def create_components(self):
        """Build the UI inside ``self.blocks`` and wire up both Go buttons."""
        # NOTE(review): Row.style(), Button(label=...) and launch's
        # enable_queue are Gradio 3.x APIs that were removed in 4.x — pin
        # gradio<4 or migrate.
        with self.blocks:
            # title bar
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=12, elem_id="header_col"):
                    self.header_title = gr.Markdown(
                        "## CLIP Tokenizer Util",
                        elem_id="header_title",
                    )
                with gr.Column(scale=1, min_width=90, elem_id="button_col"):
                    with gr.Row(elem_id="button_row"):
                        # NOTE(review): no click handler is attached to this button here.
                        self.reload_btn = gr.Button(
                            label="refresh",
                            elem_id="refresh_btn",
                            type="button",
                            value="🔄",
                            variant="primary",
                        )

            with gr.Tabs():
                with gr.Tab(label="Text Input", id="text_input_tab"):
                    with gr.Row().style(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            self.text_input = gr.Textbox(
                                label="Text Input",
                                elem_id="tokenizer_prompt",
                                show_label=False,
                                lines=8,
                                placeholder="Prompt for tokenization",
                            )
                            self.text_button = gr.Button(
                                label="Tokenize",
                                elem_id="go_button",
                                value="Go",
                                variant="primary",
                            )

                with gr.Tab(label="Token Input", id="token_input_tab"):
                    with gr.Row().style(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            # This box feeds tokenize_ids, which parses
                            # comma-separated integer ids.
                            self.token_input = gr.Textbox(
                                lines=5,
                                label="Token Input",
                                elem_id="text_input",
                                placeholder="Comma-separated token ids",
                            )
                            self.token_button = gr.Button(
                                label="Tokenize",
                                elem_id="go_button",
                                type="button",
                                value="Go",
                                variant="primary",
                            )

            with gr.Tabs():
                with gr.Tab("Text"):
                    tokenized_text = gr.HTML(elem_id="tokenized_text")
                with gr.Tab("Tokens"):
                    tokenized_ids = gr.HTML(elem_id="tokenized_ids")

            self.text_button.click(
                fn=self.tokenize,
                inputs=[self.text_input],
                outputs=[tokenized_text, tokenized_ids],
            )
            self.token_button.click(
                fn=self.tokenize_ids,
                inputs=[self.token_input],
                outputs=[tokenized_text, tokenized_ids],
            )

    def launch(self, **kwargs):
        """Launch the Gradio app and return the result of ``gr.Blocks.launch``.

        Defaults: listen on all interfaces, show errors in the UI, enable
        queueing. Any keyword in *kwargs* (including those three) overrides
        the defaults and is forwarded to ``gr.Blocks.launch``.
        """
        launch_kwargs = {
            "server_name": "0.0.0.0",
            "show_error": True,
            "enable_queue": True,
        }
        # Merge instead of passing the defaults alongside **kwargs, which
        # would raise TypeError when a caller supplies one of them.
        launch_kwargs.update(kwargs)
        return self.blocks.launch(**launch_kwargs)


if __name__ == "__main__":
    clip_util = ClipUtil()
    clip_util.create_components()
    clip_util.launch()