Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
import pandas as pd | |
from functools import lru_cache | |
# ---------------------------------------------------------------------- | |
# IMPORTANT: This version uses the PatchscopesRetriever implementation | |
# from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words) | |
# ---------------------------------------------------------------------- | |
try: | |
from inner_lexicon.word_retriever import PatchscopesRetriever # pip install tokens2words | |
except ImportError: | |
PatchscopesRetriever = None | |
# Checkpoint pre-selected in the model dropdown of the UI.
DEFAULT_MODEL = "meta-llama/Llama-3.1-8B"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache(maxsize=1)
def get_model_and_tokenizer(model_name: str):
    """Load (and memoise) the causal LM and tokenizer for *model_name*.

    The result is cached so that repeated analyses with the same model do
    not reload the checkpoint on every call. Only the most recently used
    model is kept (``maxsize=1``), so selecting a different model evicts the
    previous cache entry instead of accumulating checkpoints.

    Args:
        model_name: Hugging Face model identifier passed to
            ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is moved to ``DEVICE`` and
        put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        output_hidden_states=True,
    ).to(DEVICE)
    model.eval()
    return model, tokenizer
def find_last_token_index(full_ids, word_ids):
    """Return the index in *full_ids* of the last token of *word_ids*.

    Only the first occurrence of *word_ids* inside *full_ids* is considered.

    Args:
        full_ids: Sequence of token ids to search in.
        word_ids: Sequence of token ids to search for.

    Returns:
        The position of the final element of the first match, or ``None``
        when *word_ids* is empty or does not occur in *full_ids*.
    """
    # Normalise to lists so a tuple/list mix still compares equal slice-wise.
    haystack = list(full_ids)
    needle = list(word_ids)
    span = len(needle)

    # An empty needle would "match" at offset 0 and yield -1, which Python
    # treats as a valid (last-element) index; report "not found" instead.
    if span == 0:
        return None

    for start in range(len(haystack) - span + 1):
        if haystack[start:start + span] == needle:
            return start + span - 1
    return None
def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str) -> str:
    """Run Patchscopes retrieval for *word* and render the result as HTML.

    Args:
        model_name: Model identifier handed to ``get_model_and_tokenizer``.
        extraction_template: Extraction prompt entered in the UI. It is
            currently not used (see the review note in the body).
        word: The word to analyse.
        patchscopes_template: Prompt passed to the retriever as the
            Patchscopes prompt, with ``X`` as the placeholder.

    Returns:
        An HTML fragment with the word's sub-word tokens, the number of
        layers whose retrieved text equals the word, and a per-layer table
        with matching rows highlighted. A red error paragraph is returned
        instead when the Patchscopes library is missing or *word* is blank.
    """
    import html  # local import keeps this block self-contained

    if PatchscopesRetriever is None:
        return (
            "<p style='color:red'>❌ Patchscopes library not found. Run:<br/>"
            "<code>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</code></p>"
        )

    # A blank word tokenises to nothing and would request zero generated
    # tokens from the retriever; reject it up front.
    if not word or not word.strip():
        return "<p style='color:red'>❌ Please enter a word to analyse.</p>"

    model, tokenizer = get_model_and_tokenizer(model_name)

    # NOTE(review): `extraction_template` is accepted but never forwarded;
    # the extraction prompt is fixed to the bare placeholder. Confirm whether
    # the value from the UI is meant to be used here.
    extraction_prompt = "X"

    # Token ids of the word alone, used below to display its sub-word pieces.
    word_token_ids = tokenizer.encode(word, add_special_tokens=False)

    patch_retriever = PatchscopesRetriever(
        model,
        tokenizer,
        extraction_prompt,
        patchscopes_template,
        prompt_target_placeholder="X",
    )

    # Only the first element of the retriever's return value is used; it is
    # iterated below as one retrieved string per layer.
    retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
        word,
        num_tokens_to_generate=len(tokenizer.tokenize(word)),
    )[0]

    # Build one table row per layer; a layer matches when the retrieved text
    # equals the word after trimming surrounding spaces.
    target = word.strip(" ")
    records = []
    for layer_idx, ret_word in enumerate(retrieved_words):
        is_match = ret_word.strip(" ") == target
        records.append({"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if is_match else ""})
    matches = sum(1 for record in records if record["Match?"])

    df = pd.DataFrame(records, columns=["Layer", "Retrieved", "Match?"])

    def _style(row):
        color = "background-color: lightgreen" if row["Match?"] else ""
        return [color] * len(row)

    # Retrieved text is model output and may contain HTML metacharacters;
    # have the Styler escape cell contents when rendering.
    html_table = (
        df.style.apply(_style, axis=1)
        .format(escape="html")
        .hide(axis="index")
        .to_html()
    )

    # Tokens such as "<0x0A>" would otherwise be parsed as HTML tags and
    # disappear from the page, so escape them before interpolation.
    sub_tokens = [html.escape(str(tok)) for tok in tokenizer.convert_ids_to_tokens(word_token_ids)]
    top = (
        f"<p><b>Sub‑word tokens:</b> {' , '.join(sub_tokens)}</p>"
        f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
    )
    return top + html_table
# ----------------------------- GRADIO UI -------------------------------
# NOTE(review): the nesting of widgets inside the row was reconstructed;
# confirm the intended layout.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown(
        """# Tokens→Words Viewer\nInteractively inspect how hidden‑state patching (Patchscopes) reveals a word's detokenised representation across model layers."""
    )
    with gr.Row():
        model_name = gr.Dropdown(
            label="🤖 Model",
            # Llama-2 uses the "-hf" repository: the model name is passed to
            # transformers' from_pretrained, which needs the
            # Transformers-format checkpoint.
            choices=[
                DEFAULT_MODEL,
                "mistralai/Mistral-7B-v0.1",
                "meta-llama/Llama-2-7b-hf",
                "Qwen/Qwen2-7B",
            ],
            value=DEFAULT_MODEL,
        )
        extraction_template = gr.Textbox(
            label="Extraction prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
        patchscopes_template = gr.Textbox(
            label="Patchscopes prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
        word_box = gr.Textbox(label="Word to test", value="interpretable")
        run_btn = gr.Button("Analyse")
    out_html = gr.HTML()

    # Clicking the button runs the analysis and shows the returned HTML.
    run_btn.click(
        analyse_word,
        inputs=[model_name, extraction_template, word_box, patchscopes_template],
        outputs=out_html,
    )

if __name__ == "__main__":
    demo.launch()