Update app.py
app.py CHANGED
@@ -45,7 +45,7 @@ suggested_interpretation_prompts = ["Before responding, let me repeat the messag
 def initialize_gpu():
     pass
 
-def get_hidden_states(raw_original_prompt):
+def get_hidden_states(progress, raw_original_prompt):
     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
@@ -56,7 +56,7 @@ def get_hidden_states(raw_original_prompt):
     return [hidden_states, *token_btns]
 
 
-def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
+def run_interpretation(progress, global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
 
@@ -82,7 +82,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     # generate the interpretations
     generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
-    return [gr.
+    return [gr.TextBox(text, visible=True, container=False) for text in generation_texts]
 
 
 ## main
@@ -137,7 +137,7 @@ css = '''
 
 # '''
 
-
+progress = gr.Progress()
 with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
     global_state = gr.State([])
     with gr.Row():
@@ -147,7 +147,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
 
     👾 **This space is a simple introduction to the emerging trend of models interpreting their _own hidden states_ in free form natural language**!! 👾
     This idea was explored in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was later investigated further in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
-    An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my own work!! 🥳) which was
+    An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my own work!! 🥳) which was less mature but had the same idea in mind.
     We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
 
     👾 **The idea is really simple: models are able to understand their own hidden states by nature!** 👾
@@ -158,7 +158,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
         gr.Markdown('<span style="font-size:180px;">🤔</span>')
 
         with gr.Group():
-            original_prompt_raw = gr.Textbox(value='
+            original_prompt_raw = gr.Textbox(value='Should I eat cake or vegetables?', container=True, label='Original Prompt')
             original_prompt_btn = gr.Button('Compute', variant='primary')
 
         with gr.Accordion(open=False, label='Settings'):
@@ -179,21 +179,21 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
 
     with gr.Group('Output'):
         tokens_container = []
-        interpretation_bubbles = []
         with gr.Row():
             for i in range(MAX_PROMPT_TOKENS):
                 btn = gr.Button('', visible=False, elem_classes=['token_btn'])
                 tokens_container.append(btn)
-
-
-
-
+        progress.render()
+        interpretation_bubbles = [gr.TextBox('', container=False, visible=False, elem_classes=['bubble'])
+                                  for i in range(model.config.num_hidden_layers)]
+
     for i, btn in enumerate(tokens_container):
-        btn.click(partial(run_interpretation, i=i), [
-
+        btn.click(partial(run_interpretation, i=i), [progress,
+                  global_state, interpretation_prompt, num_tokens, do_sample, temperature,
+                  top_k, top_p, repetition_penalty, length_penalty
         ], [*interpretation_bubbles])
 
     original_prompt_btn.click(get_hidden_states,
-        [original_prompt_raw],
+        [progress, original_prompt_raw],
         [global_state, *tokens_container])
 demo.launch()
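
For readers less familiar with the Gradio patterns this change leans on, below is a minimal, self-contained sketch of the same wiring: several buttons sharing one handler through functools.partial, hidden output textboxes updated by returning component instances, and progress reporting via gr.Progress. Every name in it (interpret_token, TOKENS, N_LAYERS, the time.sleep placeholder) is illustrative rather than taken from app.py, and it assumes a recent Gradio release where gr.Progress is injected through a default argument and returning a component instance updates that component.

from functools import partial
import time

import gradio as gr

TOKENS = ["Should", "I", "eat", "cake", "?"]   # stand-in for the tokenized prompt
N_LAYERS = 4                                   # stand-in for model.config.num_hidden_layers

def interpret_token(i, progress=gr.Progress()):
    # progress.tqdm wraps the loop so the UI shows one progress step per "layer"
    texts = []
    for layer in progress.tqdm(range(N_LAYERS), desc=f"interpreting token {i}"):
        time.sleep(0.2)                        # placeholder for the real generation call
        texts.append(f"token {TOKENS[i]!r}, layer {layer}: ...")
    # returning component instances makes the hidden textboxes visible with new
    # values, analogous to the interpretation bubbles in the diff above
    return [gr.Textbox(value=t, visible=True) for t in texts]

with gr.Blocks() as demo:
    token_btns = [gr.Button(tok) for tok in TOKENS]
    bubbles = [gr.Textbox(visible=False, container=False) for _ in range(N_LAYERS)]
    for i, btn in enumerate(token_btns):
        # partial(...) bakes the token index into each button's shared callback
        btn.click(partial(interpret_token, i), outputs=bubbles)

demo.launch()

Binding the loop index with partial avoids the late-binding pitfall of defining a lambda inside the loop, which is presumably why app.py takes the same route for its per-token buttons.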