import os
from copy import deepcopy
from functools import partial
import spaces
import gradio as gr
import torch
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from interpret import InterpretationPrompt
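# InterpretationPrompt (from the local interpret.py module) builds a prompt that contains
# placeholder positions; during generation the placeholders are (presumably) overwritten
# with the hidden-state vectors we want the model to describe.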
MAX_PROMPT_TOKENS = 30
## info
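# Per-model configuration. Besides the HF loading kwargs, each entry defines two templates:
#   * original_prompt_template       - wraps the user's prompt in the model's chat format
#   * interpretation_prompt_template - the interpretation prompt, where [X] marks the spot
#     that will be patched with the hidden state under inspection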
model_info = {
    'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=os.environ['hf_token'],
                      original_prompt_template='<s>[INST] {prompt} [/INST]',
                      interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                      ),  # , load_in_8bit=True

    'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
                     original_prompt_template='<bos> {prompt}',
                     interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
                     ),

    'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
                                original_prompt_template='<s>[INST] {prompt} [/INST]',
                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                ),

    # 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
    #                                                tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
    #                                                model_type='llama', hf=True, ctransformers=True,
    #                                                original_prompt_template='<s>[INST] {prompt} [/INST]',
    #                                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    #                                                )
}

suggested_interpretation_prompts = ["Before responding, let me repeat the message you wrote:",
                                    "Let me repeat the message:",
                                    "Sure, I'll summarize your message:"]
## functions
@spaces.GPU
def initialize_gpu():
    pass
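# Runs a forward pass on the formatted prompt, stores all layers' hidden states, and
# reveals one button per input token (padded up to MAX_PROMPT_TOKENS) so the user can
# choose which position to interpret.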
def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
    token_btns = ([gr.Button(token, visible=True) for token in tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
    progress_dummy_output = ''
    return [progress_dummy_output, hidden_states, *token_btns]
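# Takes the stacked hidden states for the clicked token position i (one vector per layer),
# wraps the user's interpretation prompt in the model's template, and lets the model decode
# each vector into natural language; the outputs fill one text bubble per layer.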
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                       num_beams=1):
    interpreted_vectors = global_state[:, i]
    length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it

    # generation parameters
    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }

    # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
    interpretation_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt)
    interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)

    # generate the interpretations
    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
    generation_texts = tokenizer.batch_decode(generated)
    progress_dummy_output = ''
    return ([progress_dummy_output] +
            [gr.Textbox(text.replace('\n', ' '), visible=True, container=False) for text in generation_texts]
            )
## main
torch.set_grad_enabled(False)
model_name = 'LLAMA2-7B'
# extract model info
model_args = deepcopy(model_info[model_name])
model_path = model_args.pop('model_path')
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
# get model
model = AutoModelClass.from_pretrained(model_path, **model_args)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
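# Note: meta-llama/Llama-2-7b-chat-hf is a gated checkpoint, so the 'hf_token' environment
# variable must hold a Hugging Face token that has been granted access to it.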
# demo
json_output = gr.JSON()
css = '''
.bubble {
border: 2px solid #000;
border-radius: 10px;
padding: 10px;
margin-left: 5%;
width: 90%;
background: pink;
}
.bubble textarea {
border: none;
box-shadow: none;
background: inherit;
resize: none;
}
.explanation_accordion .svelte-s1r2yt{
font-weight: bold;
text-align: start;
}
'''
# '''
# .token_btn{
# background-color: none;
# background: none;
# border: none;
# padding: 0;
# font: inherit;
# cursor: pointer;
# color: blue; /* default text color */
# font-weight: bold;
# }
# .token_btn:hover {
# color: red;
# }
# '''
with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    global_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# π Self-Interpreting Models')
            with gr.Accordion(
                    label='πΎ This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!! πΎ',
                    elem_classes=['explanation_accordion']
            ):
                gr.Markdown(
                    '''This idea was explored in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was later investigated further in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
An honorable mention goes to **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my own work!! π₯³), which was less mature but had the same idea in mind.
For concreteness, this space follows the SelfIE implementation. Patchscopes is so general that it encompasses many other interpretation techniques too!!!
''', line_breaks=True)
            with gr.Accordion(label='πΎ The idea is really simple: models are able to understand their own hidden states by nature! πΎ',
                              elem_classes=['explanation_accordion']):
                gr.Markdown(
                    '''If I give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace ``[X]`` *during computation* with the hidden state we want to understand,
we hope to get back a summary of the information that exists inside that hidden state, because it is encoded in the same latent space the model itself uses!! How cool is that! π―π―π―
''', line_breaks=True)
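                # A rough sketch of the mechanism (illustration only -- the hook-based approach and
                # names below are assumptions, not necessarily what interpret.py actually does):
                #
                #   def patch_placeholder(module, inputs, output, pos, vec):
                #       output[0][:, pos] = vec   # overwrite the [X] position's activation
                #       return output
                #
                # i.e. the [X] token's activation is replaced by the hidden state we want explained,
                # and generation then continues as usual.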
        with gr.Column(scale=1):
            gr.Markdown('<span style="font-size:180px;">π€</span>')

    # with gr.Group():
    #     with gr.Row():
    #         for txt in model_info.keys():
    #             btn = gr.Button(txt)
    #             model_btns.append(btn)
    #     for btn in model_btns:
    #         btn.click(reset_new_model, [global_state])
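    # Placeholder tabs for planned example categories; their contents are not implemented yet.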
    with gr.Blocks():
        with gr.Tab('Memory Recall'):
            pass
        with gr.Tab('Physics Understanding'):
            pass
        with gr.Tab('Common Sense'):
            pass
        with gr.Tab('LLM Attacks'):
            pass
    with gr.Group():
        original_prompt_raw = gr.Textbox(value='Should I eat cake or vegetables?', container=True, label='Original Prompt')
        original_prompt_btn = gr.Button('Compute', variant='primary')

    tokens_container = []
    with gr.Row():
        for i in range(MAX_PROMPT_TOKENS):
            btn = gr.Button('', visible=False, elem_classes=['token_btn'])
            tokens_container.append(btn)
    progress_dummy = gr.Markdown('', elem_id='progress_dummy')
    with gr.Group('Interpretation'):
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
        with gr.Accordion(open=False, label='Settings'):
            with gr.Row():
                num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
                repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
                length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
                # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
                do_sample = gr.Checkbox(label='With sampling')
            with gr.Accordion(label='Sampling Parameters'):
                with gr.Row():
                    temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                    top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                    top_p = gr.Slider(0., 1., value=0.95, label='top p')
    with gr.Group('Output'):
        interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
                                  for i in range(model.config.num_hidden_layers)]
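    # Event wiring: clicking a token button interprets the hidden states at that position,
    # and the 'Compute' button re-tokenizes the prompt and refreshes the hidden states.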
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i),
                  [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
                   top_k, top_p, repetition_penalty, length_penalty],
                  [progress_dummy, *interpretation_bubbles])

    original_prompt_btn.click(get_hidden_states,
                              [original_prompt_raw],
                              [progress_dummy, global_state, *tokens_container])
demo.launch()