File size: 4,528 Bytes
b7e1c46
 
0f64adf
b7e1c46
 
0f64adf
b7e1c46
 
 
 
 
9e67a8e
b7e1c46
 
0f64adf
b7e1c46
af9dbe3
 
 
b7e1c46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import pandas as pd
from functools import lru_cache

# ----------------------------------------------------------------------
# IMPORTANT: This version uses the PatchscopesRetriever implementation
# from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words)
# ----------------------------------------------------------------------
try:
    from inner_lexicon.word_retriever import PatchscopesRetriever  # pip install tokens2words
except ImportError:
    PatchscopesRetriever = None

# Default checkpoint — deliberately a single mid-size model so the demo can
# start on common hardware without extra configuration.
DEFAULT_MODEL = "meta-llama/Llama-3.1-8B"  # light default so the demo boots everywhere

# Prefer the GPU when one is visible; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@lru_cache(maxsize=4)
def get_model_and_tokenizer(model_name: str):
    """Load a causal LM + tokenizer for *model_name*, memoized via lru_cache.

    Up to four (model, tokenizer) pairs stay resident, so switching between
    checkpoints in the UI does not re-download / re-load each time. The model
    is loaded in bfloat16 with hidden states enabled, moved to the module-level
    DEVICE, and put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        output_hidden_states=True,
    )
    model = model.to(DEVICE)
    model.eval()
    return model, tokenizer


def find_last_token_index(full_ids, word_ids):
    """Return the index (into *full_ids*) of the LAST token of the first
    occurrence of the contiguous subsequence *word_ids*.

    Returns ``None`` when *word_ids* does not occur, and also when
    *word_ids* is empty (the original code returned -1 for an empty
    pattern, which is not a valid position).
    """
    if not word_ids:
        # An empty pattern has no meaningful "last token" position.
        return None
    pattern_len = len(word_ids)
    for start in range(len(full_ids) - pattern_len + 1):
        if full_ids[start : start + pattern_len] == word_ids:
            return start + pattern_len - 1
    return None


def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
    """Probe *word* across all layers of *model_name* via Patchscopes.

    Parameters
    ----------
    model_name:
        HF checkpoint name passed to ``get_model_and_tokenizer``.
    extraction_template:
        Prompt (containing the placeholder "X") from which hidden states
        are collected.
    word:
        The word whose detokenised representation is inspected.
    patchscopes_template:
        Prompt (containing the placeholder "X") used to decode the patched
        hidden states back into text.

    Returns
    -------
    str
        An HTML fragment: either an error message (library missing) or a
        summary line plus a per-layer table of retrieved words.
    """
    if PatchscopesRetriever is None:
        return (
            "<p style='color:red'>❌ Patchscopes library not found. Run:<br/>"
            "<code>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</code></p>"
        )

    model, tokenizer = get_model_and_tokenizer(model_name)

    # Build extraction prompt (where hidden states will be collected).
    # FIX: this was previously hard-coded to the bare string "X", silently
    # ignoring the template the user entered in the UI.
    extraction_prompt = extraction_template

    # Sub-word token ids of the probed word (reported in the summary below).
    word_token_ids = tokenizer.encode(word, add_special_tokens=False)

    # Instantiate the Patchscopes retriever from the Tokens2Words paper.
    patch_retriever = PatchscopesRetriever(
        model,
        tokenizer,
        extraction_prompt,
        patchscopes_template,
        prompt_target_placeholder="X",
    )

    # One pass retrieves the decoded word at every layer.
    retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
        word,
        num_tokens_to_generate=len(tokenizer.tokenize(word)),
    )[0]

    # Build a table summarising which layers reproduce the original word.
    records = []
    matches = 0
    for layer_idx, ret_word in enumerate(retrieved_words):
        # Leading/trailing spaces come from tokenizer detokenisation; ignore them.
        match = ret_word.strip(" ") == word.strip(" ")
        if match:
            matches += 1
        records.append({"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if match else ""})

    df = pd.DataFrame(records)

    def _style(row):
        # Highlight rows where the retrieved word matches the probe word.
        color = "background-color: lightgreen" if row["Match?"] else ""
        return [color] * len(row)

    html_table = df.style.apply(_style, axis=1).hide(axis="index").to_html(escape=False)

    sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
    top = (
        f"<p><b>Sub‑word tokens:</b> {' , '.join(sub_tokens)}</p>"
        f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
    )
    return top + html_table


# ----------------------------- GRADIO UI -------------------------------
# Declarative UI: widget construction order determines on-screen layout,
# so the statements below must stay in this order.
with gr.Blocks(theme="soft") as demo:
    # Title / description banner at the top of the page.
    gr.Markdown(
        """# Tokens→Words Viewer\nInteractively inspect how hidden‑state patching (Patchscopes) reveals a word's detokenised representation across model layers."""
    )

    with gr.Row():
        # Model picker; DEFAULT_MODEL is preselected.
        model_name = gr.Dropdown(
            label="🤖 Model",
            choices=[DEFAULT_MODEL, "mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b", "Qwen/Qwen2-7B"],
            value=DEFAULT_MODEL,
        )
        # Prompt whose hidden states are collected ("X" marks the probed word).
        extraction_template = gr.Textbox(
            label="Extraction prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
    # Prompt used to decode patched hidden states back into text.
    patchscopes_template = gr.Textbox(
        label="Patchscopes prompt (use X as placeholder)",
        value="repeat the following word X twice: 1)X 2)",
    )
    word_box = gr.Textbox(label="Word to test", value="interpretable")
    run_btn = gr.Button("Analyse")
    # analyse_word returns an HTML fragment rendered here.
    out_html = gr.HTML()

    # Wire the button: inputs must match analyse_word's parameter order.
    run_btn.click(
        analyse_word,
        inputs=[model_name, extraction_template, word_box, patchscopes_template],
        outputs=out_html,
    )

if __name__ == "__main__":
    demo.launch()