Guy24 commited on
Commit
b7e1c46
·
1 Parent(s): 0f64adf

adding application

Browse files
Files changed (3) hide show
  1. app.py +125 -4
  2. requirements.txt +213 -0
  3. word_retriever.py +189 -0
app.py CHANGED
@@ -1,7 +1,128 @@
 
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
4
+ import pandas as pd
5
+ from functools import lru_cache
6
 
7
+ # ----------------------------------------------------------------------
8
+ # IMPORTANT: This version uses the PatchscopesRetriever implementation
9
+ # from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words)
10
+ # ----------------------------------------------------------------------
11
+ try:
12
+ from word_retriever import PatchscopesRetriever # pip install tokens2words
13
+ except ImportError:
14
+ PatchscopesRetriever = None
15
 
16
+ DEFAULT_MODEL = "meta-llama/Llama-3.1-8B" # light default so the demo boots everywhere
17
+ DEVICE = 'mps'
18
+ # (
19
+ # "cuda" if torch.cuda.is_available() else ("mps" if torch.word_retriever.pybackends.mps.is_available() else "cpu")
20
+ # )
21
+
22
+ @lru_cache(maxsize=4)
23
+ def get_model_and_tokenizer(model_name: str):
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ model_name,
27
+ torch_dtype=torch.bfloat16 ,
28
+ output_hidden_states=True,
29
+ ).to(DEVICE)
30
+ model.eval()
31
+ return model, tokenizer
32
+
33
+
34
+ def find_last_token_index(full_ids, word_ids):
35
+ """Locate end position of word_ids inside full_ids (first match)."""
36
+ for i in range(len(full_ids) - len(word_ids) + 1):
37
+ if full_ids[i : i + len(word_ids)] == word_ids:
38
+ return i + len(word_ids) - 1
39
+ return None
40
+
41
+
42
+ def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
43
+ if PatchscopesRetriever is None:
44
+ return (
45
+ "<p style='color:red'>❌ Patchscopes library not found. Run:<br/>"
46
+ "<code>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</code></p>"
47
+ )
48
+
49
+ model, tokenizer = get_model_and_tokenizer(model_name)
50
+
51
+ # Build extraction prompt (where hidden states will be collected)
52
+ extraction_prompt ="X"
53
+
54
+ # Identify last token position of the *word* inside the prompt IDs
55
+ word_token_ids = tokenizer.encode(word, add_special_tokens=False)
56
+
57
+ # Instantiate Patchscopes retriever
58
+ patch_retriever = PatchscopesRetriever(
59
+ model,
60
+ tokenizer,
61
+ extraction_prompt,
62
+ patchscopes_template,
63
+ prompt_target_placeholder="X",
64
+ )
65
+
66
+ # Run retrieval for the word across all layers (one pass)
67
+ retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
68
+ word,
69
+ num_tokens_to_generate=len(tokenizer.tokenize(word)),
70
+ )[0]
71
+
72
+ # Build a table summarising which layers match
73
+ records = []
74
+ matches = 0
75
+ for layer_idx, ret_word in enumerate(retrieved_words):
76
+ match = ret_word.strip(" ") == word.strip(" ")
77
+ if match:
78
+ matches += 1
79
+ records.append({"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if match else ""})
80
+
81
+ df = pd.DataFrame(records)
82
+
83
+ def _style(row):
84
+ color = "background-color: lightgreen" if row["Match?"] else ""
85
+ return [color] * len(row)
86
+
87
+ html_table = df.style.apply(_style, axis=1).hide(axis="index").to_html(escape=False)
88
+
89
+ sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
90
+ top = (
91
+ f"<p><b>Sub‑word tokens:</b> {' , '.join(sub_tokens)}</p>"
92
+ f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
93
+ )
94
+ return top + html_table
95
+
96
+
97
+ # ----------------------------- GRADIO UI -------------------------------
98
+ with gr.Blocks(theme="soft") as demo:
99
+ gr.Markdown(
100
+ """# Tokens→Words Viewer\nInteractively inspect how hidden‑state patching (Patchscopes) reveals a word's detokenised representation across model layers."""
101
+ )
102
+
103
+ with gr.Row():
104
+ model_name = gr.Dropdown(
105
+ label="🤖 Model",
106
+ choices=[DEFAULT_MODEL, "mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b", "Qwen/Qwen2-7B"],
107
+ value=DEFAULT_MODEL,
108
+ )
109
+ extraction_template = gr.Textbox(
110
+ label="Extraction prompt (use X as placeholder)",
111
+ value="repeat the following word X twice: 1)X 2)",
112
+ )
113
+ patchscopes_template = gr.Textbox(
114
+ label="Patchscopes prompt (use X as placeholder)",
115
+ value="repeat the following word X twice: 1)X 2)",
116
+ )
117
+ word_box = gr.Textbox(label="Word to test", value="interpretable")
118
+ run_btn = gr.Button("Analyse")
119
+ out_html = gr.HTML()
120
+
121
+ run_btn.click(
122
+ analyse_word,
123
+ inputs=[model_name, extraction_template, word_box, patchscopes_template],
124
+ outputs=out_html,
125
+ )
126
+
127
+ if __name__ == "__main__":
128
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.2.1
2
+ aiofiles==23.2.1
3
+ aiohappyeyeballs==2.4.4
4
+ aiohttp==3.11.11
5
+ aiosignal==1.3.2
6
+ annotated-types==0.7.0
7
+ anyio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_68kdsx8iyd/croot/anyio_1729121281958/work
8
+ appnope @ file:///Users/ktietz/demo/mc3/conda-bld/appnope_1629146036738/work
9
+ argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work
10
+ argon2-cffi-bindings @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2ef471wnyf/croot/argon2-cffi-bindings_1736182451265/work
11
+ asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
12
+ async-lru @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_02efro5ps8/croot/async-lru_1699554529181/work
13
+ async-timeout==5.0.1
14
+ attrs @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93pjmt0git/croot/attrs_1734533120523/work
15
+ Babel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_00k1rl2pus/croot/babel_1671781944131/work
16
+ backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
17
+ beautifulsoup4 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_94rx5n7wo9/croot/beautifulsoup4-split_1718029832430/work
18
+ bleach @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_faqg19k8gh/croot/bleach_1732292152791/work
19
+ blis==1.2.0
20
+ Brotli @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f7i0oxypt6/croot/brotli-split_1736182464088/work
21
+ catalogue==2.0.10
22
+ certifi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d8j59rqun5/croot/certifi_1734473289913/work/certifi
23
+ cffi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e4xd9yd9i2/croot/cffi_1736182819442/work
24
+ charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
25
+ click==8.1.8
26
+ cloudpathlib==0.20.0
27
+ cloudpickle==3.1.0
28
+ comm @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_3doui0bmzb/croot/comm_1709322861485/work
29
+ confection==0.1.5
30
+ contourpy==1.3.0
31
+ cycler==0.12.1
32
+ cymem==2.0.11
33
+ dask==2024.8.0
34
+ dask-expr==1.1.10
35
+ datasets==3.2.0
36
+ debugpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_563_nwtkoc/croot/debugpy_1690905063850/work
37
+ decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
38
+ defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
39
+ -e git+https://github.com/tokeron/diffusers.git@00769b5d64c2ea35201e0df7a082db3513619afe#egg=diffusers&subdirectory=../../../../../../diffusers
40
+ dill==0.3.8
41
+ distro==1.9.0
42
+ docker-pycreds==0.4.0
43
+ editdistance==0.8.1
44
+ en_core_web_lg @ https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.8.0/en_core_web_lg-3.8.0-py3-none-any.whl#sha256=293e9547a655b25499198ab15a525b05b9407a75f10255e405e8c3854329ab63
45
+ en_core_web_md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl#sha256=5e6329fe3fecedb1d1a02c3ea2172ee0fede6cea6e4aefb6a02d832dba78a310
46
+ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
47
+ eval_type_backport==0.2.2
48
+ exceptiongroup @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b2258scr33/croot/exceptiongroup_1706031391815/work
49
+ executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
50
+ fastapi==0.115.12
51
+ fastjsonschema @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d1wgyi4enb/croot/python-fastjsonschema_1731939426145/work
52
+ ffmpy==0.5.0
53
+ filelock==3.16.1
54
+ fonttools==4.55.3
55
+ frozenlist==1.5.0
56
+ fsspec==2024.9.0
57
+ gitdb==4.0.12
58
+ GitPython==3.1.44
59
+ gradio==4.44.1
60
+ gradio_client==1.3.0
61
+ h11 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_110bmw2coo/croot/h11_1706652289620/work
62
+ httpcore @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_fcxiho9nv7/croot/httpcore_1706728465004/work
63
+ httpx @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_cc4egw1482/croot/httpx_1723474826664/work
64
+ huggingface-hub==0.27.1
65
+ idna==3.10
66
+ importlib_metadata @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_cc4qelzghy/croot/importlib_metadata-suite_1732633706960/work
67
+ importlib_resources==6.5.2
68
+ ipykernel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ddflobe9t3/croot/ipykernel_1728665605034/work
69
+ ipython @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_6599f73fa7/croot/ipython_1694181355402/work
70
+ jedi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_38ctoinnl0/croot/jedi_1733987402850/work
71
+ Jinja2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b15nuwux5r/croot/jinja2_1730902833938/work
72
+ jiter==0.8.2
73
+ joblib==1.4.2
74
+ json5 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b9ww6ewhv3/croot/json5_1730786813588/work
75
+ jsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_7boelfqucq/croot/jsonschema_1728486715888/work
76
+ jsonschema-specifications @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d38pclgu95/croot/jsonschema-specifications_1699032390832/work
77
+ jupyter-events @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_db0avcjzq5/croot/jupyter_events_1718738111427/work
78
+ jupyter-lsp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_ae9br5v37x/croot/jupyter-lsp-meta_1699978259353/work
79
+ jupyter_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_58w2siozyz/croot/jupyter_client_1699455907045/work
80
+ jupyter_core @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_73nomeum4p/croot/jupyter_core_1718818302815/work
81
+ jupyter_server @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d1t69bk94b/croot/jupyter_server_1718827086930/work
82
+ jupyter_server_terminals @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e7ryd60iuw/croot/jupyter_server_terminals_1686870731283/work
83
+ jupyterlab @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a2d0br6r6g/croot/jupyterlab_1725895226942/work
84
+ jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
85
+ jupyterlab_server @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f64fg3hglz/croot/jupyterlab_server_1725865356410/work
86
+ kiwisolver==1.4.7
87
+ langcodes==3.5.0
88
+ language_data==1.3.0
89
+ locket==1.0.0
90
+ marisa-trie==1.2.1
91
+ markdown-it-py==3.0.0
92
+ MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work
93
+ matplotlib==3.9.4
94
+ matplotlib-inline @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f6fdc0hldi/croots/recipe/matplotlib-inline_1662014472341/work
95
+ matplotlib-venn==1.1.2
96
+ mdurl==0.1.2
97
+ mistune @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_17ya6k1sbs/croots/recipe/mistune_1661496228719/work
98
+ mpmath==1.3.0
99
+ multidict==6.1.0
100
+ multiprocess==0.70.16
101
+ murmurhash==1.0.12
102
+ nbclient @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_626hpwnurm/croot/nbclient_1698934218848/work
103
+ nbconvert @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f4c1s1qk1f/croot/nbconvert_1728049432295/work
104
+ nbformat @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2cv_qoc1gw/croot/nbformat_1728049423516/work
105
+ nest-asyncio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_310vb5e2a0/croot/nest-asyncio_1708532678212/work
106
+ networkx==3.2.1
107
+ nltk==3.9.1
108
+ notebook @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_539v4hufo2/croot/notebook_1727199149603/work
109
+ notebook_shim @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d6_ze10f45/croot/notebook-shim_1699455897525/work
110
+ numpy==2.0.2
111
+ openai==1.59.7
112
+ orjson==3.10.16
113
+ overrides @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_70s80guh9g/croot/overrides_1699371144462/work
114
+ packaging @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a6_qk3qyg7/croot/packaging_1734472142254/work
115
+ pandas==2.2.3
116
+ pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work
117
+ parso @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_8824a1w4md/croot/parso_1733963320105/work
118
+ partd==1.4.2
119
+ patsy==1.0.1
120
+ pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
121
+ pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
122
+ pillow==10.4.0
123
+ platformdirs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a8u4fy8k9o/croot/platformdirs_1692205661656/work
124
+ plotly==5.24.1
125
+ preshed==3.0.9
126
+ prometheus_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_803ymjpv2u/croot/prometheus_client_1731958793251/work
127
+ prompt-toolkit @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_c63v4kqjzr/croot/prompt-toolkit_1704404354115/work
128
+ propcache==0.2.1
129
+ protobuf==5.29.2
130
+ psutil @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1310b568-21f4-4cb0-b0e3-2f3d31e39728k9coaga5/croots/recipe/psutil_1656431280844/work
131
+ ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
132
+ pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
133
+ pyarrow==18.1.0
134
+ pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
135
+ pydantic==2.10.4
136
+ pydantic_core==2.27.2
137
+ pydub==0.25.1
138
+ Pygments @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_29bs9f_dh9/croot/pygments_1684279974747/work
139
+ pyparsing==3.2.1
140
+ PySocks @ file:///Users/ktietz/Code/oss/ci_pkgs/pysocks_1626781349491/work
141
+ python-box==7.3.0
142
+ python-dateutil @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_66ud1l42_h/croot/python-dateutil_1716495741162/work
143
+ python-json-logger @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9bjmcmh4nm/croot/python-json-logger_1734370248301/work
144
+ python-multipart==0.0.20
145
+ pytz @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a4b76c83ik/croot/pytz_1713974318928/work
146
+ PyYAML @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_faoex52hrr/croot/pyyaml_1728657970485/work
147
+ pyzmq @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_95lsut8ymz/croot/pyzmq_1734709560733/work
148
+ referencing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_5cz64gsx70/croot/referencing_1699012046031/work
149
+ regex==2024.11.6
150
+ requests @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ee45nsd33z/croot/requests_1730999134038/work
151
+ rfc3339-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_76ae5cu30h/croot/rfc3339-validator_1683077051957/work
152
+ rfc3986-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0l5zd97kt/croot/rfc3986-validator_1683058998431/work
153
+ rich==13.9.4
154
+ rpds-py @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_93fzmr7v9h/croot/rpds-py_1732228422522/work
155
+ ruff==0.11.6
156
+ safetensors==0.5.0
157
+ scikit-learn==1.6.0
158
+ scipy==1.13.1
159
+ seaborn==0.13.2
160
+ semantic-version==2.10.0
161
+ Send2Trash @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5b31f0zzlv/croot/send2trash_1699371144121/work
162
+ sentencepiece==0.2.0
163
+ sentry-sdk==2.19.2
164
+ setproctitle==1.3.4
165
+ shellingham==1.5.4
166
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
167
+ smart-open==7.1.0
168
+ smmap==5.0.2
169
+ sniffio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1573pknjrg/croot/sniffio_1705431298885/work
170
+ soupsieve @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9798xzs_03/croot/soupsieve_1696347567192/work
171
+ spacy==3.8.3
172
+ spacy-legacy==3.0.12
173
+ spacy-loggers==1.0.5
174
+ srsly==2.5.1
175
+ stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
176
+ starlette==0.46.2
177
+ statsmodels==0.14.4
178
+ swifter==1.4.0
179
+ sympy==1.13.1
180
+ tabulate==0.9.0
181
+ tenacity==9.0.0
182
+ terminado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcfvyc0an2/croot/terminado_1671751835701/work
183
+ thinc==8.3.4
184
+ threadpoolctl==3.5.0
185
+ tinycss2 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcw5_i306t/croot/tinycss2_1668168825117/work
186
+ together==1.4.1
187
+ tokenizers==0.21.0
188
+ tomli @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0e5ffbf-5cf1-45be-8693-c5dff8108a2awhthtjlq/croots/recipe/tomli_1657175508477/work
189
+ tomlkit==0.12.0
190
+ toolz==1.0.0
191
+ torch==2.5.1
192
+ torchvision==0.20.1
193
+ tornado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_0axef5a0m0/croot/tornado_1733960501260/work
194
+ tqdm==4.67.1
195
+ traitlets @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_500m2_1wyk/croot/traitlets_1718227071952/work
196
+ transformers==4.47.1
197
+ typer==0.15.1
198
+ typing_extensions @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_0b3jpv_f79/croot/typing_extensions_1734714864260/work
199
+ tzdata==2024.2
200
+ urllib3 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_06_m8gdsy6/croot/urllib3_1727769822458/work
201
+ uvicorn==0.34.2
202
+ wandb==0.19.1
203
+ wasabi==1.1.3
204
+ wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
205
+ weasel==0.4.1
206
+ webencodings==0.5.1
207
+ websocket-client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d37u7gqts8/croot/websocket-client_1715878310260/work
208
+ websockets==12.0
209
+ wordcloud==1.9.4
210
+ wrapt==1.17.2
211
+ xxhash==3.5.0
212
+ yarl==1.18.3
213
+ zipp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_echurpkwug/croot/zipp_1732630743967/work
word_retriever.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ from abc import ABC, abstractmethod
4
+
5
+ from .utils.enums import MultiTokenKind, RetrievalTechniques
6
+ from .processor import RetrievalProcessor
7
+ from .utils.logit_lens import ReverseLogitLens
8
+ from .utils.model_utils import extract_token_i_hidden_states
9
+
10
+
11
+ class WordRetrieverBase(ABC):
12
+ def __init__(self, model, tokenizer):
13
+ self.model = model
14
+ self.tokenizer = tokenizer
15
+
16
+ @abstractmethod
17
+ def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
18
+ pass
19
+
20
+
21
+ class PatchscopesRetriever(WordRetrieverBase):
22
+ def __init__(
23
+ self,
24
+ model,
25
+ tokenizer,
26
+ representation_prompt: str = "{word}",
27
+ patchscopes_prompt: str = "Next is the same word twice: 1) {word} 2)",
28
+ prompt_target_placeholder: str = "{word}",
29
+ representation_token_idx_to_extract: int = -1,
30
+ num_tokens_to_generate: int = 10,
31
+ ):
32
+ super().__init__(model, tokenizer)
33
+ self.prompt_input_ids, self.prompt_target_idx = \
34
+ self._build_prompt_input_ids_template(patchscopes_prompt, prompt_target_placeholder)
35
+ self._prepare_representation_prompt = \
36
+ self._build_representation_prompt_func(representation_prompt, prompt_target_placeholder)
37
+ self.representation_token_idx = representation_token_idx_to_extract
38
+ self.num_tokens_to_generate = num_tokens_to_generate
39
+
40
+ def _build_prompt_input_ids_template(self, prompt, target_placeholder):
41
+ prompt_input_ids = [self.tokenizer.bos_token_id] if self.tokenizer.bos_token_id is not None else []
42
+ target_idx = []
43
+
44
+ if prompt:
45
+ assert target_placeholder is not None, \
46
+ "Trying to set a prompt for Patchscopes without defining the prompt's target placeholder string, e.g., [MASK]"
47
+
48
+ prompt_parts = prompt.split(target_placeholder)
49
+ for part_i, prompt_part in enumerate(prompt_parts):
50
+ prompt_input_ids += self.tokenizer.encode(prompt_part, add_special_tokens=False)
51
+ if part_i < len(prompt_parts)-1:
52
+ target_idx += [len(prompt_input_ids)]
53
+ prompt_input_ids += [0]
54
+ else:
55
+ prompt_input_ids += [0]
56
+ target_idx = [len(prompt_input_ids)]
57
+
58
+ prompt_input_ids = torch.tensor(prompt_input_ids, dtype=torch.long)
59
+ target_idx = torch.tensor(target_idx, dtype=torch.long)
60
+ return prompt_input_ids, target_idx
61
+
62
+ def _build_representation_prompt_func(self, prompt, target_placeholder):
63
+ return lambda word: prompt.replace(target_placeholder, word)
64
+
65
+ def generate_states(self, tokenizer, word='Wakanda', with_prompt=True):
66
+ prompt = self.generate_prompt() if with_prompt else word
67
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
68
+ return input_ids
69
+
70
+ def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=None):
71
+ self.model.eval()
72
+
73
+ # insert hidden states into patchscopes prompt
74
+ if hidden_states.dim() == 1:
75
+ hidden_states = hidden_states.unsqueeze(0)
76
+
77
+ inputs_embeds = self.model.get_input_embeddings()(self.prompt_input_ids.to(self.model.device)).unsqueeze(0)
78
+ batched_patchscope_inputs = inputs_embeds.repeat(len(hidden_states), 1, 1).to(hidden_states.dtype)
79
+ batched_patchscope_inputs[:, self.prompt_target_idx] = hidden_states.unsqueeze(1).to(self.model.device)
80
+
81
+ attention_mask = (self.prompt_input_ids != self.tokenizer.eos_token_id).long().unsqueeze(0).repeat(
82
+ len(hidden_states), 1).to(self.model.device)
83
+
84
+ num_tokens_to_generate = num_tokens_to_generate if num_tokens_to_generate else self.num_tokens_to_generate
85
+
86
+ with torch.no_grad():
87
+ patchscope_outputs = self.model.generate(
88
+ do_sample=False, num_beams=1, top_p=1.0, temperature=None,
89
+ inputs_embeds=batched_patchscope_inputs,# attention_mask=attention_mask,
90
+ max_new_tokens=num_tokens_to_generate, pad_token_id=self.tokenizer.eos_token_id, )
91
+
92
+ decoded_patchscope_outputs = self.tokenizer.batch_decode(patchscope_outputs)
93
+ return decoded_patchscope_outputs
94
+
95
+ def extract_hidden_states(self, word):
96
+ representation_input = self._prepare_representation_prompt(word)
97
+
98
+ last_token_hidden_states = extract_token_i_hidden_states(
99
+ self.model, self.tokenizer, representation_input, token_idx_to_extract=self.representation_token_idx, return_dict=False, verbose=False)
100
+
101
+ return last_token_hidden_states
102
+
103
+ def get_hidden_states_and_retrieve_word(self, word, num_tokens_to_generate=None):
104
+ last_token_hidden_states = self.extract_hidden_states(word)
105
+ patchscopes_description_by_layers = self.retrieve_word(
106
+ last_token_hidden_states, num_tokens_to_generate=num_tokens_to_generate)
107
+
108
+ return patchscopes_description_by_layers, last_token_hidden_states
109
+
110
+
111
+ class ReverseLogitLensRetriever(WordRetrieverBase):
112
+ def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
113
+ super().__init__(model, tokenizer)
114
+ self.reverse_logit_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)
115
+
116
+ def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
117
+ result = self.reverse_logit_lens(hidden_states, layer_idx)
118
+ token = self.tokenizer.decode(torch.argmax(result, dim=-1).item())
119
+ return token
120
+
121
+
122
+ class AnalysisWordRetriever:
123
+ def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate=1, add_context=True,
124
+ model_name='LLaMa-2B', device='cuda', dataset=None):
125
+ self.model = model.to(device)
126
+ self.tokenizer = tokenizer
127
+ self.multi_token_kind = multi_token_kind
128
+ self.num_tokens_to_generate = num_tokens_to_generate
129
+ self.add_context = add_context
130
+ self.model_name = model_name
131
+ self.device = device
132
+ self.dataset = dataset
133
+ self.retriever = self._initialize_retriever()
134
+ self.RetrievalTechniques = (RetrievalTechniques.Patchscopes if self.multi_token_kind == MultiTokenKind.Natural
135
+ else RetrievalTechniques.ReverseLogitLens)
136
+ self.whitespace_token = 'Ġ' if model_name in ['gemma-2-9b', 'pythia-6.9b', 'LLaMA3-8B', 'Yi-6B'] else '▁'
137
+ self.processor = RetrievalProcessor(self.model, self.tokenizer, self.multi_token_kind,
138
+ self.num_tokens_to_generate, self.add_context, self.model_name,
139
+ self.whitespace_token)
140
+
141
+ def _initialize_retriever(self):
142
+ if self.multi_token_kind == MultiTokenKind.Natural:
143
+ return PatchscopesRetriever(self.model, self.tokenizer)
144
+ else:
145
+ return ReverseLogitLensRetriever(self.model, self.tokenizer)
146
+
147
+ def retrieve_words_in_dataset(self, number_of_examples_to_retrieve=2, max_length=1000):
148
+ self.model.eval()
149
+ results = []
150
+
151
+ for text in tqdm(self.dataset['train']['text'][:number_of_examples_to_retrieve], self.model_name):
152
+ tokenized_input = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length).to(
153
+ self.device)
154
+ tokens = tokenized_input.input_ids[0]
155
+ print(f'Processing text: {text}')
156
+ i = 5
157
+ while i < len(tokens):
158
+ if self.multi_token_kind == MultiTokenKind.Natural:
159
+ j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word = self.processor.get_next_word(
160
+ tokens, i, device=self.device)
161
+ elif self.multi_token_kind == MultiTokenKind.Typo:
162
+ j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word = self.processor.get_next_full_word_typo(
163
+ tokens, i, device=self.device)
164
+ else:
165
+ j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word = self.processor.get_next_full_word_separated(
166
+ tokens, i, device=self.device)
167
+
168
+ if len(word_tokens) > 1:
169
+ with torch.no_grad():
170
+ outputs = self.model(**tokenized_combined_text, output_hidden_states=True)
171
+
172
+ hidden_states = outputs.hidden_states
173
+ for layer_idx, hidden_state in enumerate(hidden_states):
174
+ postfix_hidden_state = hidden_states[layer_idx][0, -1, :].unsqueeze(0)
175
+ retrieved_word_str = self.retriever.retrieve_word(postfix_hidden_state, layer_idx=layer_idx,
176
+ num_tokens_to_generate=len(word_tokens))
177
+ results.append({
178
+ 'text': combined_text,
179
+ 'original_word': original_word,
180
+ 'word': word,
181
+ 'word_tokens': self.tokenizer.convert_ids_to_tokens(word_tokens),
182
+ 'num_tokens': len(word_tokens),
183
+ 'layer': layer_idx,
184
+ 'retrieved_word_str': retrieved_word_str,
185
+ 'context': "With Context" if self.add_context else "Without Context"
186
+ })
187
+ else:
188
+ i = j
189
+ return results