import os
import gc
from typing import Optional
from dataclasses import dataclass
from copy import deepcopy
from functools import partial
import spaces
import gradio as gr
import torch
from datasets import load_dataset
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
from interpret import InterpretationPrompt
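# maximum number of prompt tokens the UI supports: Gradio needs a fixed set of components,
# so this many token buttons are created up front and kept hidden until a prompt is analyzed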
MAX_PROMPT_TOKENS = 60
## info
dataset_info = [
{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
{'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
'filter': lambda x: x['label'] == 1},
# {'name': 'Physical Understanding', 'hf_repo': 'piqa', 'text_col': 'goal'},
{'name': 'Social Reasoning', 'hf_repo': 'ProlificAI/social-reasoning-rlhf', 'text_col': 'question'}
]
model_info = {
'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=os.environ['hf_token'],
original_prompt_template='<s>{prompt}',
interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
), # , load_in_8bit=True
'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
original_prompt_template='<bos>{prompt}',
interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
),
'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
original_prompt_template='<s>{prompt}',
interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
),
# 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
# tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
# model_type='llama', hf=True, ctransformers=True,
# original_prompt_template='<s>[INST] {prompt} [/INST]',
# interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
# )
}
suggested_interpretation_prompts = [
"Sure, here's a bullet list of the key words in your message:",
"Sure, I'll summarize your message:",
"Sure, here are the words in your message:",
"Before responding, let me repeat the message you wrote:",
"Let me repeat the message:"
]
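# shared state across Gradio callbacks: the loaded model/tokenizer, the cached hidden states
# of the last analyzed prompt, and the prompt templates of the currently selected model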
@dataclass
class GlobalState:
tokenizer : Optional[PreTrainedTokenizer] = None
model : Optional[PreTrainedModel] = None
hidden_states : Optional[torch.Tensor] = None
interpretation_prompt_template : str = '{prompt}'
original_prompt_template : str = '{prompt}'
## functions
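# no-op decorated with @spaces.GPU so that ZeroGPU hardware is requested for this Space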
@spaces.GPU
def initialize_gpu():
pass
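# swap the loaded model: drop the previous model/tokenizer/hidden states from global_state,
# then load the newly selected model and its prompt templates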
def reset_model(model_name):
# extract model info
model_args = deepcopy(model_info[model_name])
model_path = model_args.pop('model_path')
global_state.original_prompt_template = model_args.pop('original_prompt_template')
global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
# get model
global_state.model, global_state.tokenizer, global_state.hidden_states = None, None, None
gc.collect()
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
gc.collect()
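# run the original prompt through the model once, cache every layer's hidden states,
# and reveal one button per prompt token (the remaining buttons stay hidden)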
def get_hidden_states(raw_original_prompt):
model, tokenizer = global_state.model, global_state.tokenizer
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
token_btns = ([gr.Button(token, visible=True) for token in tokens]
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
progress_dummy_output = ''
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
global_state.hidden_states = hidden_states
return [progress_dummy_output, *token_btns, *invisible_bubbles]
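# interpret token position i: take its hidden state from every layer, patch each one into the
# [X] placeholder of the interpretation prompt, and decode the model's free-form description
# (one text bubble per layer)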
@spaces.GPU
def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
num_beams=1):
interpreted_vectors = global_state.hidden_states[:, i]
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
# generation parameters
generation_kwargs = {
'max_new_tokens': int(max_new_tokens),
'do_sample': do_sample,
'temperature': temperature,
'top_k': int(top_k),
'top_p': top_p,
'repetition_penalty': repetition_penalty,
'length_penalty': length_penalty,
'num_beams': int(num_beams)
}
# create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)
# generate the interpretations
# generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
generation_texts = global_state.tokenizer.batch_decode(generated)
progress_dummy_output = ''
return ([progress_dummy_output] +
[gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {layer}') for layer, text in enumerate(generation_texts)]
)
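# Illustrative sketch (not used by the app). `interpret.InterpretationPrompt` is defined in
# interpret.py and is not shown in this file, so the helper below is only an assumption about
# the general SelfIE/Patchscopes mechanism it wraps: overwrite the residual stream at the
# placeholder position of the interpretation prompt with a hidden state taken from another
# prompt/layer, then generate as usual. It assumes a LLaMA-style module tree
# (model.model.layers) and PyTorch >= 2.0 forward pre-hooks.
def patch_and_generate_sketch(model, tokenizer, interpretation_prompt, placeholder_position,
                              vector, inject_layer=0, **generation_kwargs):
    inputs = tokenizer(interpretation_prompt, return_tensors='pt').to(model.device)

    def pre_hook(module, args, kwargs):
        # the decoder layer receives the running hidden states either positionally or as a kwarg
        hidden_states = args[0] if args else kwargs['hidden_states']
        if hidden_states.shape[1] > 1:  # patch only the prefill pass, not cached decoding steps
            hidden_states[:, placeholder_position] = vector.to(hidden_states.device, hidden_states.dtype)

    handle = model.model.layers[inject_layer].register_forward_pre_hook(pre_hook, with_kwargs=True)
    try:
        generated = model.generate(**inputs, **generation_kwargs)
    finally:
        handle.remove()
    return tokenizer.batch_decode(generated, skip_special_tokens=True)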
## main
torch.set_grad_enabled(False)
global_state = GlobalState()
model_name = 'LLAMA2-7B'
reset_model(model_name)
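# these components are created up front (so callbacks can reference them) and .render()ed
# in their place inside the layout below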
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
tokens_container = []
for i in range(MAX_PROMPT_TOKENS):
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
tokens_container.append(btn)
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
with gr.Row():
with gr.Column(scale=5):
gr.Markdown('# Self-Interpreting Models')
gr.Markdown('<b style="color: #8B0000;">Model outputs are not filtered and might include undesired language!</b>')
# gr.Markdown(
# '**πΎ This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!πΎ**',
# # elem_classes=['explanation_accordion']
# )
gr.Markdown(
'''
**This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free-form natural language!**
This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
An honorable mention goes to **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work), which was less mature but had the same idea in mind.
We will follow the SelfIE implementation in this space for concreteness. Patchscopes is so general that it encompasses many other interpretation techniques too!
''', line_breaks=True)
# gr.Markdown('**πΎ The idea is really simple: models are able to understand their own hidden states by nature! πΎ**',
# # elem_classes=['explanation_accordion']
# )
gr.Markdown(
'''
**The idea is really simple: models are able to understand their own hidden states by nature!**
In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
So we can inject a representation from (roughly) any layer into any other layer! If we give the model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
we expect to get back a summary of the information that exists inside that hidden state, even though it comes from a different layer and a different run! How cool is that?!
''', line_breaks=True)
# with gr.Column(scale=1):
# gr.Markdown('<span style="font-size:180px;">π€</span>')
with gr.Group():
model_chooser = gr.Radio(choices=list(model_info.keys()), value=model_name)
gr.Markdown('## Choose Your Interpretation Prompt')
with gr.Group():
interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
gr.Examples([[p] for p in suggested_interpretation_prompts], [interpretation_prompt], cache_examples=False)
gr.Markdown('## The Prompt to Analyze')
for info in dataset_info:
with gr.Tab(info['name']):
num_examples = 10
dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
if 'filter' in info:
dataset = dataset.filter(info['filter'])
dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
dataset = [[row[info['text_col']]] for row in dataset]
gr.Examples(dataset, [original_prompt_raw], cache_examples=False)
with gr.Group():
original_prompt_raw.render()
original_prompt_btn = gr.Button('Output Token List', variant='primary')
gr.Markdown('### The tokens of the prompt appear here (click on one to explore it)')
with gr.Row():
for btn in tokens_container:
btn.render()
with gr.Accordion(open=False, label='Generation Settings'):
with gr.Row():
num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
# num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
do_sample = gr.Checkbox(label='With sampling')
with gr.Accordion(label='Sampling Parameters'):
with gr.Row():
temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
top_p = gr.Slider(0., 1., value=0.95, label='top p')
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
) for i in range(global_state.model.config.num_hidden_layers)]
# event listeners
model_chooser.change(reset_model, [model_chooser], [])
for i, btn in enumerate(tokens_container):
btn.click(partial(run_interpretation, i=i),
[interpretation_prompt, num_tokens, do_sample, temperature,
top_k, top_p, repetition_penalty, length_penalty],
[progress_dummy, *interpretation_bubbles])
original_prompt_btn.click(get_hidden_states,
[original_prompt_raw],
[progress_dummy, *tokens_container, *interpretation_bubbles])
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
demo.launch()