import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import pandas as pd
from functools import lru_cache

# ----------------------------------------------------------------------
# IMPORTANT: This version uses the PatchscopesRetriever implementation
# from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words)
# ----------------------------------------------------------------------
try:
    from Tokens2Words.word_retriever import PatchscopesRetriever  # pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words
except ImportError:
    PatchscopesRetriever = None

DEFAULT_MODEL = "meta-llama/Llama-3.1-8B"  # default model; swap in a smaller checkpoint on constrained hosts
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@lru_cache(maxsize=4)
def get_model_and_tokenizer(model_name: str):
    """Load a model/tokenizer pair once and cache it for later calls."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        output_hidden_states=True,
    ).to(DEVICE)
    model.eval()
    return model, tokenizer


def find_last_token_index(full_ids, word_ids):
    """Locate the end position of word_ids inside full_ids (first match).

    Currently unused by the demo; kept as a utility.
    """
    for i in range(len(full_ids) - len(word_ids) + 1):
        if full_ids[i : i + len(word_ids)] == word_ids:
            return i + len(word_ids) - 1
    return None
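# Illustrative check of the helper above, using made-up token IDs: the
# sub-sequence [12, 7] ends at index 2 inside [5, 12, 7, 9].
#   find_last_token_index([5, 12, 7, 9], [12, 7])  # -> 2
#   find_last_token_index([5, 12, 7, 9], [7, 5])   # -> None (no match)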

def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
    if PatchscopesRetriever is None:
        return (
            "<p>❌ Patchscopes library not found. Run:</p>"
            "<pre>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</pre>"
        )

    model, tokenizer = get_model_and_tokenizer(model_name)

    # Build the extraction prompt (where hidden states will be collected)
    extraction_prompt = extraction_template

    # Sub-word token IDs of the target word (shown in the summary below)
    word_token_ids = tokenizer.encode(word, add_special_tokens=False)

    # Instantiate the Patchscopes retriever
    patch_retriever = PatchscopesRetriever(
        model,
        tokenizer,
        extraction_prompt,
        patchscopes_template,
        prompt_target_placeholder="X",
    )

    # Run retrieval for the word across all layers (one pass)
    retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
        word,
        num_tokens_to_generate=len(tokenizer.tokenize(word)),
    )[0]
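    # `retrieved_words` holds the text decoded from each layer's patched
    # hidden state (one entry per layer, as the loop below assumes); the
    # trailing [0] keeps that list and discards the call's other outputs.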
    # Build a table summarising which layers match
    records = []
    matches = 0
    for layer_idx, ret_word in enumerate(retrieved_words):
        match = ret_word.strip(" ") == word.strip(" ")
        if match:
            matches += 1
        records.append(
            {"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if match else ""}
        )

    df = pd.DataFrame(records)

    def _style(row):
        # Highlight rows whose retrieved word matches the input word
        color = "background-color: lightgreen" if row["Match?"] else ""
        return [color] * len(row)

    html_table = df.style.apply(_style, axis=1).hide(axis="index").to_html(escape=False)

    sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
    top = (
        f"<p><b>Sub-word tokens:</b> {' , '.join(sub_tokens)}</p>"
        f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
    )
    return top + html_table
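# A minimal sketch of calling analyse_word() directly (no UI), mirroring the
# default prompts of the widgets below; assumes Tokens2Words is installed and
# the chosen model fits on the current device:
#
#   html = analyse_word(
#       DEFAULT_MODEL,
#       "repeat the following word X twice: 1)X 2)",
#       "interpretable",
#       "repeat the following word X twice: 1)X 2)",
#   )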

# ----------------------------- GRADIO UI -------------------------------
with gr.Blocks(theme="soft") as demo:
    gr.Markdown(
        """# Tokens→Words Viewer
Interactively inspect how hidden-state patching (Patchscopes) reveals a word's
detokenised representation across model layers."""
    )

    with gr.Row():
        model_name = gr.Dropdown(
            label="🤖 Model",
            choices=[
                DEFAULT_MODEL,
                "mistralai/Mistral-7B-v0.1",
                "meta-llama/Llama-2-7b",
                "Qwen/Qwen2-7B",
            ],
            value=DEFAULT_MODEL,
        )

    extraction_template = gr.Textbox(
        label="Extraction prompt (use X as placeholder)",
        value="repeat the following word X twice: 1)X 2)",
    )
    patchscopes_template = gr.Textbox(
        label="Patchscopes prompt (use X as placeholder)",
        value="repeat the following word X twice: 1)X 2)",
    )
    word_box = gr.Textbox(label="Word to test", value="interpretable")

    run_btn = gr.Button("Analyse")
    out_html = gr.HTML()

    run_btn.click(
        analyse_word,
        inputs=[model_name, extraction_template, word_box, patchscopes_template],
        outputs=out_html,
    )

if __name__ == "__main__":
    demo.launch()