Spaces:
Sleeping
Sleeping
adding application
Browse files- app.py +125 -4
- requirements.txt +213 -0
- word_retriever.py +189 -0
app.py
CHANGED
@@ -1,7 +1,128 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
import gradio as gr
|
4 |
+
import pandas as pd
|
5 |
+
from functools import lru_cache
|
6 |
|
7 |
+
# ----------------------------------------------------------------------
|
8 |
+
# IMPORTANT: This version uses the PatchscopesRetriever implementation
|
9 |
+
# from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words)
|
10 |
+
# ----------------------------------------------------------------------
|
11 |
+
try:
|
12 |
+
from word_retriever import PatchscopesRetriever # pip install tokens2words
|
13 |
+
except ImportError:
|
14 |
+
PatchscopesRetriever = None
|
15 |
|
16 |
+
# Checkpoint that is pre-selected in the model dropdown of the UI.
DEFAULT_MODEL = "meta-llama/Llama-3.1-8B"

# Choose the accelerator at import time instead of hard-coding Apple's "mps"
# backend, which makes `.to(DEVICE)` fail on hosts without Metal support.
# Preference order: CUDA GPU, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
|
21 |
+
|
22 |
+
@lru_cache(maxsize=4)
def get_model_and_tokenizer(model_name: str):
    """Load a causal LM together with its tokenizer, memoised per name.

    The ``lru_cache`` decorator keeps up to four results alive, so asking for
    a checkpoint that was already loaded returns the cached pair.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in bfloat16 with
        ``output_hidden_states=True``, moved to ``DEVICE`` and put in eval
        mode.
    """
    tok = AutoTokenizer.from_pretrained(model_name)

    load_options = {
        "torch_dtype": torch.bfloat16,
        "output_hidden_states": True,
    }
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    lm = lm.to(DEVICE)
    lm.eval()  # inference only

    return lm, tok
|
32 |
+
|
33 |
+
|
34 |
+
def find_last_token_index(full_ids, word_ids):
    """Return the index in ``full_ids`` of the last token of ``word_ids``.

    Scans left to right and reports the first place where ``word_ids`` occurs
    as a contiguous run inside ``full_ids``.

    Args:
        full_ids: Sequence of token ids to search in.
        word_ids: Sequence of token ids to look for.

    Returns:
        The position (within ``full_ids``) of the final token of the first
        match, or ``None`` when ``word_ids`` is empty or does not occur.
    """
    # Normalise to lists so that e.g. a tuple and a list with the same ids
    # compare equal.
    haystack = list(full_ids)
    needle = list(word_ids)

    # An empty needle trivially "matches" at offset 0, which previously
    # produced the meaningless index -1; treat it as "not found".
    if not needle:
        return None

    width = len(needle)
    for start in range(len(haystack) - width + 1):
        if haystack[start:start + width] == needle:
            return start + width - 1
    return None
|
40 |
+
|
41 |
+
|
42 |
+
def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
    """Run Patchscopes retrieval for ``word`` and render the outcome as HTML.

    For every entry returned by the retriever (enumerated as "layers") the
    decoded text is compared with ``word`` (ignoring surrounding spaces) and
    the result is shown as a table, preceded by the word's sub-word tokens and
    the number of matching layers.

    Args:
        model_name: Checkpoint name passed to ``get_model_and_tokenizer``.
        extraction_template: Currently unused; see the NOTE below.
        word: The word whose representation is inspected.
        patchscopes_template: Patchscopes prompt in which ``"X"`` marks the
            position(s) to patch.

    Returns:
        An HTML string: either an error paragraph (blank word, or retriever
        class unavailable) or the summary followed by the per-layer table.
    """
    import html  # local import: only needed to escape the rendered output

    # A blank word has no tokens; asking the retriever for 0 new tokens would
    # silently fall back to its default length, so reject it up front.
    if not word or not word.strip():
        return "<p style='color:red'>❌ Please enter a word to analyse.</p>"

    if PatchscopesRetriever is None:
        return (
            "<p style='color:red'>❌ Patchscopes library not found. Run:<br/>"
            "<code>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</code></p>"
        )

    model, tokenizer = get_model_and_tokenizer(model_name)

    # NOTE(review): ``extraction_template`` is not used; the extraction prompt
    # is fixed to the bare placeholder, i.e. the word on its own. Confirm
    # whether the UI textbox is meant to take effect before wiring it in.
    extraction_prompt = "X"

    # Token ids of the word itself (no special tokens), shown in the summary.
    word_token_ids = tokenizer.encode(word, add_special_tokens=False)

    patch_retriever = PatchscopesRetriever(
        model,
        tokenizer,
        extraction_prompt,
        patchscopes_template,
        prompt_target_placeholder="X",
    )

    # One call returns (decoded texts, hidden states); only the texts are used.
    retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
        word,
        num_tokens_to_generate=len(tokenizer.tokenize(word)),
    )[0]

    # Build a table summarising which layers reproduce the word.
    target = word.strip(" ")
    records = []
    for layer_idx, ret_word in enumerate(retrieved_words):
        is_match = ret_word.strip(" ") == target
        records.append({"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if is_match else ""})
    matches = sum(1 for record in records if record["Match?"])

    df = pd.DataFrame(records)

    def _style(row):
        # Highlight the whole row when the layer matched.
        color = "background-color: lightgreen" if row["Match?"] else ""
        return [color] * len(row)

    # The retrieved text is model output and may contain markup-like content
    # (e.g. special tokens in angle brackets); escape it so it is displayed
    # literally instead of being interpreted as HTML.
    html_table = (
        df.style.format(escape="html")
        .apply(_style, axis=1)
        .hide(axis="index")
        .to_html()
    )

    sub_tokens = [html.escape(token) for token in tokenizer.convert_ids_to_tokens(word_token_ids)]
    top = (
        f"<p><b>Sub‑word tokens:</b> {' , '.join(sub_tokens)}</p>"
        f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
    )
    return top + html_table
|
95 |
+
|
96 |
+
|
97 |
+
# ----------------------------- GRADIO UI -------------------------------
|
98 |
+
# NOTE(review): the nesting below was reconstructed from a listing without
# indentation; confirm which components are meant to sit inside the row.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown(
        """# Tokens→Words Viewer\nInteractively inspect how hidden‑state patching (Patchscopes) reveals a word's detokenised representation across model layers."""
    )

    with gr.Row():
        model_name = gr.Dropdown(
            label="🤖 Model",
            # "meta-llama/Llama-2-7b" holds the original (non-transformers)
            # checkpoint layout; the "-hf" repository is the one that
            # ``from_pretrained`` can load.
            choices=[DEFAULT_MODEL, "mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b-hf", "Qwen/Qwen2-7B"],
            value=DEFAULT_MODEL,
        )
        extraction_template = gr.Textbox(
            label="Extraction prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
        patchscopes_template = gr.Textbox(
            label="Patchscopes prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
    word_box = gr.Textbox(label="Word to test", value="interpretable")
    run_btn = gr.Button("Analyse")
    out_html = gr.HTML()

    # The input order must match the parameter order of ``analyse_word``.
    run_btn.click(
        analyse_word,
        inputs=[model_name, extraction_template, word_box, patchscopes_template],
        outputs=out_html,
    )

if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==1.2.1
|
2 |
+
aiofiles==23.2.1
|
3 |
+
aiohappyeyeballs==2.4.4
|
4 |
+
aiohttp==3.11.11
|
5 |
+
aiosignal==1.3.2
|
6 |
+
annotated-types==0.7.0
|
7 |
+
anyio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_68kdsx8iyd/croot/anyio_1729121281958/work
|
8 |
+
appnope @ file:///Users/ktietz/demo/mc3/conda-bld/appnope_1629146036738/work
|
9 |
+
argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work
|
10 |
+
argon2-cffi-bindings @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2ef471wnyf/croot/argon2-cffi-bindings_1736182451265/work
|
11 |
+
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
|
12 |
+
async-lru @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_02efro5ps8/croot/async-lru_1699554529181/work
|
13 |
+
async-timeout==5.0.1
|
14 |
+
attrs @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93pjmt0git/croot/attrs_1734533120523/work
|
15 |
+
Babel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_00k1rl2pus/croot/babel_1671781944131/work
|
16 |
+
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
|
17 |
+
beautifulsoup4 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_94rx5n7wo9/croot/beautifulsoup4-split_1718029832430/work
|
18 |
+
bleach @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_faqg19k8gh/croot/bleach_1732292152791/work
|
19 |
+
blis==1.2.0
|
20 |
+
Brotli @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f7i0oxypt6/croot/brotli-split_1736182464088/work
|
21 |
+
catalogue==2.0.10
|
22 |
+
certifi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d8j59rqun5/croot/certifi_1734473289913/work/certifi
|
23 |
+
cffi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e4xd9yd9i2/croot/cffi_1736182819442/work
|
24 |
+
charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
|
25 |
+
click==8.1.8
|
26 |
+
cloudpathlib==0.20.0
|
27 |
+
cloudpickle==3.1.0
|
28 |
+
comm @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_3doui0bmzb/croot/comm_1709322861485/work
|
29 |
+
confection==0.1.5
|
30 |
+
contourpy==1.3.0
|
31 |
+
cycler==0.12.1
|
32 |
+
cymem==2.0.11
|
33 |
+
dask==2024.8.0
|
34 |
+
dask-expr==1.1.10
|
35 |
+
datasets==3.2.0
|
36 |
+
debugpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_563_nwtkoc/croot/debugpy_1690905063850/work
|
37 |
+
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
|
38 |
+
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
|
39 |
+
-e git+https://github.com/tokeron/diffusers.git@00769b5d64c2ea35201e0df7a082db3513619afe#egg=diffusers&subdirectory=../../../../../../diffusers
|
40 |
+
dill==0.3.8
|
41 |
+
distro==1.9.0
|
42 |
+
docker-pycreds==0.4.0
|
43 |
+
editdistance==0.8.1
|
44 |
+
en_core_web_lg @ https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.8.0/en_core_web_lg-3.8.0-py3-none-any.whl#sha256=293e9547a655b25499198ab15a525b05b9407a75f10255e405e8c3854329ab63
|
45 |
+
en_core_web_md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl#sha256=5e6329fe3fecedb1d1a02c3ea2172ee0fede6cea6e4aefb6a02d832dba78a310
|
46 |
+
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
|
47 |
+
eval_type_backport==0.2.2
|
48 |
+
exceptiongroup @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b2258scr33/croot/exceptiongroup_1706031391815/work
|
49 |
+
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
|
50 |
+
fastapi==0.115.12
|
51 |
+
fastjsonschema @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d1wgyi4enb/croot/python-fastjsonschema_1731939426145/work
|
52 |
+
ffmpy==0.5.0
|
53 |
+
filelock==3.16.1
|
54 |
+
fonttools==4.55.3
|
55 |
+
frozenlist==1.5.0
|
56 |
+
fsspec==2024.9.0
|
57 |
+
gitdb==4.0.12
|
58 |
+
GitPython==3.1.44
|
59 |
+
gradio==4.44.1
|
60 |
+
gradio_client==1.3.0
|
61 |
+
h11 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_110bmw2coo/croot/h11_1706652289620/work
|
62 |
+
httpcore @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_fcxiho9nv7/croot/httpcore_1706728465004/work
|
63 |
+
httpx @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_cc4egw1482/croot/httpx_1723474826664/work
|
64 |
+
huggingface-hub==0.27.1
|
65 |
+
idna==3.10
|
66 |
+
importlib_metadata @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_cc4qelzghy/croot/importlib_metadata-suite_1732633706960/work
|
67 |
+
importlib_resources==6.5.2
|
68 |
+
ipykernel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ddflobe9t3/croot/ipykernel_1728665605034/work
|
69 |
+
ipython @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_6599f73fa7/croot/ipython_1694181355402/work
|
70 |
+
jedi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_38ctoinnl0/croot/jedi_1733987402850/work
|
71 |
+
Jinja2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b15nuwux5r/croot/jinja2_1730902833938/work
|
72 |
+
jiter==0.8.2
|
73 |
+
joblib==1.4.2
|
74 |
+
json5 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b9ww6ewhv3/croot/json5_1730786813588/work
|
75 |
+
jsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_7boelfqucq/croot/jsonschema_1728486715888/work
|
76 |
+
jsonschema-specifications @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d38pclgu95/croot/jsonschema-specifications_1699032390832/work
|
77 |
+
jupyter-events @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_db0avcjzq5/croot/jupyter_events_1718738111427/work
|
78 |
+
jupyter-lsp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_ae9br5v37x/croot/jupyter-lsp-meta_1699978259353/work
|
79 |
+
jupyter_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_58w2siozyz/croot/jupyter_client_1699455907045/work
|
80 |
+
jupyter_core @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_73nomeum4p/croot/jupyter_core_1718818302815/work
|
81 |
+
jupyter_server @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d1t69bk94b/croot/jupyter_server_1718827086930/work
|
82 |
+
jupyter_server_terminals @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e7ryd60iuw/croot/jupyter_server_terminals_1686870731283/work
|
83 |
+
jupyterlab @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a2d0br6r6g/croot/jupyterlab_1725895226942/work
|
84 |
+
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
|
85 |
+
jupyterlab_server @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f64fg3hglz/croot/jupyterlab_server_1725865356410/work
|
86 |
+
kiwisolver==1.4.7
|
87 |
+
langcodes==3.5.0
|
88 |
+
language_data==1.3.0
|
89 |
+
locket==1.0.0
|
90 |
+
marisa-trie==1.2.1
|
91 |
+
markdown-it-py==3.0.0
|
92 |
+
MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work
|
93 |
+
matplotlib==3.9.4
|
94 |
+
matplotlib-inline @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f6fdc0hldi/croots/recipe/matplotlib-inline_1662014472341/work
|
95 |
+
matplotlib-venn==1.1.2
|
96 |
+
mdurl==0.1.2
|
97 |
+
mistune @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_17ya6k1sbs/croots/recipe/mistune_1661496228719/work
|
98 |
+
mpmath==1.3.0
|
99 |
+
multidict==6.1.0
|
100 |
+
multiprocess==0.70.16
|
101 |
+
murmurhash==1.0.12
|
102 |
+
nbclient @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_626hpwnurm/croot/nbclient_1698934218848/work
|
103 |
+
nbconvert @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f4c1s1qk1f/croot/nbconvert_1728049432295/work
|
104 |
+
nbformat @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2cv_qoc1gw/croot/nbformat_1728049423516/work
|
105 |
+
nest-asyncio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_310vb5e2a0/croot/nest-asyncio_1708532678212/work
|
106 |
+
networkx==3.2.1
|
107 |
+
nltk==3.9.1
|
108 |
+
notebook @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_539v4hufo2/croot/notebook_1727199149603/work
|
109 |
+
notebook_shim @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d6_ze10f45/croot/notebook-shim_1699455897525/work
|
110 |
+
numpy==2.0.2
|
111 |
+
openai==1.59.7
|
112 |
+
orjson==3.10.16
|
113 |
+
overrides @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_70s80guh9g/croot/overrides_1699371144462/work
|
114 |
+
packaging @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a6_qk3qyg7/croot/packaging_1734472142254/work
|
115 |
+
pandas==2.2.3
|
116 |
+
pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work
|
117 |
+
parso @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_8824a1w4md/croot/parso_1733963320105/work
|
118 |
+
partd==1.4.2
|
119 |
+
patsy==1.0.1
|
120 |
+
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
|
121 |
+
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
|
122 |
+
pillow==10.4.0
|
123 |
+
platformdirs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a8u4fy8k9o/croot/platformdirs_1692205661656/work
|
124 |
+
plotly==5.24.1
|
125 |
+
preshed==3.0.9
|
126 |
+
prometheus_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_803ymjpv2u/croot/prometheus_client_1731958793251/work
|
127 |
+
prompt-toolkit @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_c63v4kqjzr/croot/prompt-toolkit_1704404354115/work
|
128 |
+
propcache==0.2.1
|
129 |
+
protobuf==5.29.2
|
130 |
+
psutil @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1310b568-21f4-4cb0-b0e3-2f3d31e39728k9coaga5/croots/recipe/psutil_1656431280844/work
|
131 |
+
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
|
132 |
+
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
|
133 |
+
pyarrow==18.1.0
|
134 |
+
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
|
135 |
+
pydantic==2.10.4
|
136 |
+
pydantic_core==2.27.2
|
137 |
+
pydub==0.25.1
|
138 |
+
Pygments @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_29bs9f_dh9/croot/pygments_1684279974747/work
|
139 |
+
pyparsing==3.2.1
|
140 |
+
PySocks @ file:///Users/ktietz/Code/oss/ci_pkgs/pysocks_1626781349491/work
|
141 |
+
python-box==7.3.0
|
142 |
+
python-dateutil @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_66ud1l42_h/croot/python-dateutil_1716495741162/work
|
143 |
+
python-json-logger @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9bjmcmh4nm/croot/python-json-logger_1734370248301/work
|
144 |
+
python-multipart==0.0.20
|
145 |
+
pytz @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a4b76c83ik/croot/pytz_1713974318928/work
|
146 |
+
PyYAML @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_faoex52hrr/croot/pyyaml_1728657970485/work
|
147 |
+
pyzmq @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_95lsut8ymz/croot/pyzmq_1734709560733/work
|
148 |
+
referencing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_5cz64gsx70/croot/referencing_1699012046031/work
|
149 |
+
regex==2024.11.6
|
150 |
+
requests @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ee45nsd33z/croot/requests_1730999134038/work
|
151 |
+
rfc3339-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_76ae5cu30h/croot/rfc3339-validator_1683077051957/work
|
152 |
+
rfc3986-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0l5zd97kt/croot/rfc3986-validator_1683058998431/work
|
153 |
+
rich==13.9.4
|
154 |
+
rpds-py @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_93fzmr7v9h/croot/rpds-py_1732228422522/work
|
155 |
+
ruff==0.11.6
|
156 |
+
safetensors==0.5.0
|
157 |
+
scikit-learn==1.6.0
|
158 |
+
scipy==1.13.1
|
159 |
+
seaborn==0.13.2
|
160 |
+
semantic-version==2.10.0
|
161 |
+
Send2Trash @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5b31f0zzlv/croot/send2trash_1699371144121/work
|
162 |
+
sentencepiece==0.2.0
|
163 |
+
sentry-sdk==2.19.2
|
164 |
+
setproctitle==1.3.4
|
165 |
+
shellingham==1.5.4
|
166 |
+
six @ file:///tmp/build/80754af9/six_1644875935023/work
|
167 |
+
smart-open==7.1.0
|
168 |
+
smmap==5.0.2
|
169 |
+
sniffio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1573pknjrg/croot/sniffio_1705431298885/work
|
170 |
+
soupsieve @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9798xzs_03/croot/soupsieve_1696347567192/work
|
171 |
+
spacy==3.8.3
|
172 |
+
spacy-legacy==3.0.12
|
173 |
+
spacy-loggers==1.0.5
|
174 |
+
srsly==2.5.1
|
175 |
+
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
|
176 |
+
starlette==0.46.2
|
177 |
+
statsmodels==0.14.4
|
178 |
+
swifter==1.4.0
|
179 |
+
sympy==1.13.1
|
180 |
+
tabulate==0.9.0
|
181 |
+
tenacity==9.0.0
|
182 |
+
terminado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcfvyc0an2/croot/terminado_1671751835701/work
|
183 |
+
thinc==8.3.4
|
184 |
+
threadpoolctl==3.5.0
|
185 |
+
tinycss2 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcw5_i306t/croot/tinycss2_1668168825117/work
|
186 |
+
together==1.4.1
|
187 |
+
tokenizers==0.21.0
|
188 |
+
tomli @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0e5ffbf-5cf1-45be-8693-c5dff8108a2awhthtjlq/croots/recipe/tomli_1657175508477/work
|
189 |
+
tomlkit==0.12.0
|
190 |
+
toolz==1.0.0
|
191 |
+
torch==2.5.1
|
192 |
+
torchvision==0.20.1
|
193 |
+
tornado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_0axef5a0m0/croot/tornado_1733960501260/work
|
194 |
+
tqdm==4.67.1
|
195 |
+
traitlets @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_500m2_1wyk/croot/traitlets_1718227071952/work
|
196 |
+
transformers==4.47.1
|
197 |
+
typer==0.15.1
|
198 |
+
typing_extensions @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_0b3jpv_f79/croot/typing_extensions_1734714864260/work
|
199 |
+
tzdata==2024.2
|
200 |
+
urllib3 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_06_m8gdsy6/croot/urllib3_1727769822458/work
|
201 |
+
uvicorn==0.34.2
|
202 |
+
wandb==0.19.1
|
203 |
+
wasabi==1.1.3
|
204 |
+
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
|
205 |
+
weasel==0.4.1
|
206 |
+
webencodings==0.5.1
|
207 |
+
websocket-client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d37u7gqts8/croot/websocket-client_1715878310260/work
|
208 |
+
websockets==12.0
|
209 |
+
wordcloud==1.9.4
|
210 |
+
wrapt==1.17.2
|
211 |
+
xxhash==3.5.0
|
212 |
+
yarl==1.18.3
|
213 |
+
zipp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_echurpkwug/croot/zipp_1732630743967/work
|
word_retriever.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
|
5 |
+
from .utils.enums import MultiTokenKind, RetrievalTechniques
|
6 |
+
from .processor import RetrievalProcessor
|
7 |
+
from .utils.logit_lens import ReverseLogitLens
|
8 |
+
from .utils.model_utils import extract_token_i_hidden_states
|
9 |
+
|
10 |
+
|
11 |
+
class WordRetrieverBase(ABC):
    """Common interface for objects that turn hidden states back into text."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @abstractmethod
    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
        """Decode ``hidden_states`` into text; implemented by subclasses."""


class PatchscopesRetriever(WordRetrieverBase):
    """Retrieve words by patching hidden states into a fixed prompt.

    The Patchscopes prompt is tokenised once at construction time, with a
    dummy token id at every placeholder position. ``retrieve_word`` later
    overwrites the embeddings at those positions with the supplied hidden
    states and lets the model generate a continuation.
    """

    def __init__(
            self,
            model,
            tokenizer,
            representation_prompt: str = "{word}",
            patchscopes_prompt: str = "Next is the same word twice: 1) {word} 2)",
            prompt_target_placeholder: str = "{word}",
            representation_token_idx_to_extract: int = -1,
            num_tokens_to_generate: int = 10,
    ):
        """Prepare the prompt template and the representation-prompt builder.

        Args:
            model: Model used for embedding lookup and generation.
            tokenizer: Tokenizer matching ``model``.
            representation_prompt: Text in which the placeholder is replaced
                by the word before hidden states are extracted.
            patchscopes_prompt: Prompt whose placeholder positions receive the
                patched hidden states.
            prompt_target_placeholder: Placeholder string used in both prompts.
            representation_token_idx_to_extract: Token index forwarded to
                ``extract_token_i_hidden_states``.
            num_tokens_to_generate: Default generation length used when
                ``retrieve_word`` is not given one.
        """
        super().__init__(model, tokenizer)
        self.prompt_input_ids, self.prompt_target_idx = \
            self._build_prompt_input_ids_template(patchscopes_prompt, prompt_target_placeholder)
        self._prepare_representation_prompt = \
            self._build_representation_prompt_func(representation_prompt, prompt_target_placeholder)
        self.representation_token_idx = representation_token_idx_to_extract
        self.num_tokens_to_generate = num_tokens_to_generate

    def _build_prompt_input_ids_template(self, prompt, target_placeholder):
        """Tokenise ``prompt`` and record where the placeholders sit.

        Returns:
            ``(prompt_input_ids, target_idx)`` as ``torch.long`` tensors. Every
            placeholder occupies exactly one position filled with the dummy id
            ``0``; ``target_idx`` lists those positions.
        """
        bos_id = self.tokenizer.bos_token_id
        prompt_input_ids = [bos_id] if bos_id is not None else []
        target_idx = []

        if prompt:
            assert target_placeholder is not None, \
                "Trying to set a prompt for Patchscopes without defining the prompt's target placeholder string, e.g., [MASK]"

            prompt_parts = prompt.split(target_placeholder)
            last_part = len(prompt_parts) - 1
            for part_i, prompt_part in enumerate(prompt_parts):
                prompt_input_ids += self.tokenizer.encode(prompt_part, add_special_tokens=False)
                if part_i < last_part:
                    # The placeholder goes right after the text seen so far.
                    target_idx.append(len(prompt_input_ids))
                    prompt_input_ids.append(0)
        else:
            # No prompt text: the sequence is just the placeholder slot.
            # Record its position BEFORE appending, otherwise the index
            # points one past the end of the sequence.
            target_idx = [len(prompt_input_ids)]
            prompt_input_ids.append(0)

        prompt_input_ids = torch.tensor(prompt_input_ids, dtype=torch.long)
        target_idx = torch.tensor(target_idx, dtype=torch.long)
        return prompt_input_ids, target_idx

    def _build_representation_prompt_func(self, prompt, target_placeholder):
        """Return a function that substitutes a word for the placeholder."""
        return lambda word: prompt.replace(target_placeholder, word)

    def generate_states(self, tokenizer, word='Wakanda', with_prompt=True):
        """Encode ``word`` (optionally wrapped in the representation prompt).

        NOTE(review): this used to call ``self.generate_prompt()``, which is
        not defined on this class and raised ``AttributeError``. The
        representation prompt is used instead; confirm this is the intended
        prompt.

        Returns:
            Whatever ``tokenizer.encode(..., return_tensors='pt')`` returns.
        """
        prompt = self._prepare_representation_prompt(word) if with_prompt else word
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        return input_ids

    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=None):
        """Patch ``hidden_states`` into the prompt and decode the generations.

        Args:
            hidden_states: Tensor of hidden states; a 1-D tensor is treated as
                a single state. Each entry along the first dimension becomes
                one batch row of the patched prompt.
            layer_idx: Unused here; kept for interface compatibility.
            num_tokens_to_generate: Number of new tokens; a falsy value
                (``None`` or ``0``) selects the instance default.

        Returns:
            The list of decoded strings from ``tokenizer.batch_decode``, one
            per batch row.
        """
        self.model.eval()

        # insert hidden states into patchscopes prompt
        if hidden_states.dim() == 1:
            hidden_states = hidden_states.unsqueeze(0)

        inputs_embeds = self.model.get_input_embeddings()(self.prompt_input_ids.to(self.model.device)).unsqueeze(0)
        batched_patchscope_inputs = inputs_embeds.repeat(len(hidden_states), 1, 1).to(hidden_states.dtype)
        # Overwrite the dummy-token embeddings at every placeholder position.
        batched_patchscope_inputs[:, self.prompt_target_idx] = hidden_states.unsqueeze(1).to(self.model.device)

        num_tokens_to_generate = num_tokens_to_generate if num_tokens_to_generate else self.num_tokens_to_generate

        # Greedy decoding; no explicit attention mask is passed to generate().
        with torch.no_grad():
            patchscope_outputs = self.model.generate(
                do_sample=False, num_beams=1, top_p=1.0, temperature=None,
                inputs_embeds=batched_patchscope_inputs,
                max_new_tokens=num_tokens_to_generate, pad_token_id=self.tokenizer.eos_token_id, )

        decoded_patchscope_outputs = self.tokenizer.batch_decode(patchscope_outputs)
        return decoded_patchscope_outputs

    def extract_hidden_states(self, word):
        """Build the representation prompt for ``word`` and extract states.

        Returns:
            The result of ``extract_token_i_hidden_states`` for the configured
            token index (presumably one hidden state per layer; verify
            against that helper).
        """
        representation_input = self._prepare_representation_prompt(word)

        last_token_hidden_states = extract_token_i_hidden_states(
            self.model, self.tokenizer, representation_input, token_idx_to_extract=self.representation_token_idx, return_dict=False, verbose=False)

        return last_token_hidden_states

    def get_hidden_states_and_retrieve_word(self, word, num_tokens_to_generate=None):
        """Extract hidden states for ``word`` and decode them via the prompt.

        Returns:
            ``(decoded_texts, hidden_states)``: the output of
            ``retrieve_word`` and the states it was computed from.
        """
        last_token_hidden_states = self.extract_hidden_states(word)
        patchscopes_description_by_layers = self.retrieve_word(
            last_token_hidden_states, num_tokens_to_generate=num_tokens_to_generate)

        return patchscopes_description_by_layers, last_token_hidden_states
|
109 |
+
|
110 |
+
|
111 |
+
class ReverseLogitLensRetriever(WordRetrieverBase):
    """Retrieve a single token by applying a ``ReverseLogitLens`` module."""

    def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
        """Build the lens from ``model`` and move it to ``device``/``dtype``."""
        super().__init__(model, tokenizer)
        lens = ReverseLogitLens.from_model(model)
        lens = lens.to(device)
        self.reverse_logit_lens = lens.to(dtype)

    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
        """Decode the highest-scoring token id produced by the lens.

        ``num_tokens_to_generate`` is accepted for interface compatibility
        and is not used: exactly one token id is decoded.
        """
        scores = self.reverse_logit_lens(hidden_states, layer_idx)
        best_id = torch.argmax(scores, dim=-1).item()
        return self.tokenizer.decode(best_id)
|
120 |
+
|
121 |
+
|
122 |
+
class AnalysisWordRetriever:
    """Walk a text dataset and try to retrieve multi-token words per layer.

    Depending on ``multi_token_kind`` the words are obtained from different
    ``RetrievalProcessor`` methods and decoded either with a
    ``PatchscopesRetriever`` (``MultiTokenKind.Natural``) or with a
    ``ReverseLogitLensRetriever`` (every other kind).
    """

    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate=1, add_context=True,
                 model_name='LLaMa-2B', device='cuda', dataset=None):
        """Store the configuration and build the retriever and processor.

        Args:
            model: Model to analyse; it is moved to ``device``.
            tokenizer: Tokenizer matching ``model``.
            multi_token_kind: ``MultiTokenKind`` member selecting the word
                source and the retrieval technique.
            num_tokens_to_generate: Forwarded to ``RetrievalProcessor``.
            add_context: Forwarded to ``RetrievalProcessor``; also reported in
                each result row.
            model_name: Used as progress-bar label and to pick the tokenizer's
                whitespace marker.
            device: Device for the model, the tokenized inputs and the
                reverse-logit-lens retriever.
            dataset: Mapping read as ``dataset['train']['text']``.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        self.device = device
        self.dataset = dataset
        self.retriever = self._initialize_retriever()
        self.RetrievalTechniques = (RetrievalTechniques.Patchscopes if self.multi_token_kind == MultiTokenKind.Natural
                                    else RetrievalTechniques.ReverseLogitLens)
        # Marker character chosen from the model name.
        self.whitespace_token = 'Ġ' if model_name in ['gemma-2-9b', 'pythia-6.9b', 'LLaMA3-8B', 'Yi-6B'] else '▁'
        self.processor = RetrievalProcessor(self.model, self.tokenizer, self.multi_token_kind,
                                            self.num_tokens_to_generate, self.add_context, self.model_name,
                                            self.whitespace_token)

    def _initialize_retriever(self):
        """Create the retriever that matches ``multi_token_kind``."""
        if self.multi_token_kind == MultiTokenKind.Natural:
            return PatchscopesRetriever(self.model, self.tokenizer)
        # Build the lens on the configured device rather than on the
        # retriever's 'cuda' default, so non-CUDA setups work too.
        return ReverseLogitLensRetriever(self.model, self.tokenizer, device=self.device)

    def retrieve_words_in_dataset(self, number_of_examples_to_retrieve=2, max_length=1000):
        """Retrieve every multi-token word of the first dataset examples.

        Args:
            number_of_examples_to_retrieve: How many entries of
                ``dataset['train']['text']`` to process.
            max_length: Truncation length used when tokenizing each text.

        Returns:
            A list with one dict per (word, layer) pair holding the text, the
            word, its tokens, the layer index and the retrieved string.
        """
        self.model.eval()
        results = []

        # Pick the processor method once; it depends only on the kind.
        if self.multi_token_kind == MultiTokenKind.Natural:
            get_next = self.processor.get_next_word
        elif self.multi_token_kind == MultiTokenKind.Typo:
            get_next = self.processor.get_next_full_word_typo
        else:
            get_next = self.processor.get_next_full_word_separated

        texts = self.dataset['train']['text'][:number_of_examples_to_retrieve]
        for text in tqdm(texts, self.model_name):
            tokenized_input = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length).to(
                self.device)
            tokens = tokenized_input.input_ids[0]
            print(f'Processing text: {text}')
            # NOTE(review): scanning starts at token index 5; the reason for
            # skipping the first tokens is not visible here.
            i = 5
            while i < len(tokens):
                (j, word_tokens, word, context, tokenized_combined_text,
                 combined_text, original_word) = get_next(tokens, i, device=self.device)

                # Only words split into several tokens are analysed.
                if len(word_tokens) > 1:
                    with torch.no_grad():
                        outputs = self.model(**tokenized_combined_text, output_hidden_states=True)

                    for layer_idx, hidden_state in enumerate(outputs.hidden_states):
                        # Hidden state of the last position of the first batch row.
                        postfix_hidden_state = hidden_state[0, -1, :].unsqueeze(0)
                        retrieved_word_str = self.retriever.retrieve_word(postfix_hidden_state, layer_idx=layer_idx,
                                                                          num_tokens_to_generate=len(word_tokens))
                        results.append({
                            'text': combined_text,
                            'original_word': original_word,
                            'word': word,
                            'word_tokens': self.tokenizer.convert_ids_to_tokens(word_tokens),
                            'num_tokens': len(word_tokens),
                            'layer': layer_idx,
                            'retrieved_word_str': retrieved_word_str,
                            'context': "With Context" if self.add_context else "Without Context"
                        })

                # Always move on to the position reported by the processor;
                # advancing in only one branch leaves the loop stuck on the
                # same word forever.
                i = j
        return results
|