|
import os |
|
from copy import deepcopy |
|
from functools import partial |
|
import spaces |
|
import gradio as gr |
|
import torch |
|
from datasets import load_dataset |
|
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from interpret import InterpretationPrompt |
|
|
|
# Maximum number of prompt tokens that get a clickable button in the UI.
MAX_PROMPT_TOKENS = 30

# Datasets that supply example prompts for the UI. `text_col` is the row key whose value is
# shown as the example text; the optional `filter` predicate is applied to each streamed row
# before examples are sampled.
dataset_info = [
    {'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
    {'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
     'filter': lambda x: x['label'] == 1},
]

# Hugging Face access token, read once from the `hf_token` environment variable.
# `.get` keeps a missing variable from crashing module import. When it is unset this is None,
# and presumably only repositories that need no authentication can then be downloaded.
HF_TOKEN = os.environ.get('hf_token')

# Per-model configuration.
# - The app consumes `model_path`, `original_prompt_template` and
#   `interpretation_prompt_template` (and, optionally, `tokenizer` and `ctransformers`).
# - Every other key (e.g. `device_map`, `token`) is forwarded to `from_pretrained`.
# - In `interpretation_prompt_template`, `[X]` is the placeholder whose representation is
#   replaced, and `{prompt}` receives the user's interpretation prompt.
model_info = {
    'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=HF_TOKEN,
                      original_prompt_template='<s>[INST] {prompt} [/INST]',
                      interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                      ),

    'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=HF_TOKEN,
                     original_prompt_template='<bos> {prompt}',
                     interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
                     ),

    # Token added for consistency with the other entries (the tokenizer is always loaded
    # with a token as well); a None token is equivalent to not passing one.
    'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu', token=HF_TOKEN,
                                original_prompt_template='<s>[INST] {prompt} [/INST]',
                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                ),
}

# Interpretation prompts offered to the user; the first one is the UI default.
suggested_interpretation_prompts = ["Before responding, let me repeat the message you wrote:",
                                    "Let me repeat the message:", "Sure, I'll summarize your message:"]
|
|
|
|
|
|
|
@spaces.GPU
def initialize_gpu():
    """Do nothing; exists only as a ``spaces.GPU``-decorated function.

    The body has no effect. Presumably the decorated definition is what matters (registering a
    GPU function with the Spaces runtime at import time) — TODO confirm; it is not called
    anywhere in the visible code.
    """
    return None
|
|
|
def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
    """Run the model on the user's prompt and collect all of its hidden states.

    The prompt is wrapped in the module-level ``original_prompt_template`` and tokenized with
    ``add_special_tokens=False``, because the templates already begin with their own BOS marker.
    The model is run once with ``output_hidden_states=True``. Each returned hidden state has its
    leading (batch) dimension squeezed, is moved to the CPU, and is stacked along a new dim 0.
    Presumably the result is (num_states, seq_len, hidden_dim) — TODO confirm.

    Args:
        raw_original_prompt: Prompt text entered by the user.
        progress: Gradio progress-tracker placeholder; not otherwise used here.

    Returns:
        ``[progress_dummy_output, hidden_states, *token_btns]``, laid out to match the event's
        outputs ``[progress_dummy, global_state, *tokens_container]``:

        - an empty string that clears the progress placeholder;
        - the stacked hidden states;
        - exactly ``MAX_PROMPT_TOKENS`` button updates: one visible button per (shown) token,
          then hidden buttons for the rest.
    """
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)

    # Only MAX_PROMPT_TOKENS buttons exist. Previously a longer prompt produced more updates
    # than output components and relied on Gradio dropping the surplus. Truncate explicitly so
    # the returned list always matches the outputs. hidden_states still covers every token;
    # only the first MAX_PROMPT_TOKENS get a button.
    shown_tokens = tokens[:MAX_PROMPT_TOKENS]
    token_btns = ([gr.Button(token, visible=True) for token in shown_tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(shown_tokens))])

    progress_dummy_output = ''
    return [progress_dummy_output, hidden_states, *token_btns]
|
|
|
|
|
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
                       num_beams=1):
    """Generate natural-language interpretations of one prompt token's hidden states.

    Column ``i`` of the stacked hidden states (one vector per stacked state, for token ``i``) is
    handed to ``InterpretationPrompt.generate``. Each generation is decoded into a visible
    Textbox update.

    Args:
        global_state: Stacked hidden states produced by ``get_hidden_states``.
        raw_interpretation_prompt: User text substituted into ``interpretation_prompt_template``.
        max_new_tokens: Generation setting from the UI; cast to int before use.
        do_sample: Generation setting from the UI.
        temperature: Generation setting from the UI.
        top_k: Generation setting from the UI; cast to int before use.
        top_p: Generation setting from the UI.
        repetition_penalty: Generation setting from the UI.
        length_penalty: UI length penalty; its sign is flipped before being passed on.
        use_gpu: When truthy, ``generate`` is wrapped with ``spaces.GPU`` for this call.
        i: Index of the prompt token to interpret (bound per button with ``functools.partial``).
        num_beams: Number of beams passed to generation; cast to int.

    Returns:
        ``['', *textboxes]``: an empty string that clears the progress placeholder, followed by
        one Textbox update per decoded generation (newlines replaced by spaces).
    """
    # Every stacked state for the chosen token position.
    vectors_for_token = global_state[:, i]

    generation_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        # The UI value is negated. Presumably this is because HF's `length_penalty` > 0 favours
        # longer outputs, so a UI "penalty" must be negative here — confirm intent.
        length_penalty=-length_penalty,
        num_beams=int(num_beams),
    )

    template_text = interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    prompt_obj = InterpretationPrompt(tokenizer, template_text)

    generate_fn = prompt_obj.generate
    if use_gpu:
        generate_fn = spaces.GPU(generate_fn)

    # NOTE(review): the meaning of the `{0: ...}` mapping key and of `k=3` is defined by
    # interpret.InterpretationPrompt.generate, which is not visible here — document there.
    generated = generate_fn(model, {0: vectors_for_token}, k=3, **generation_kwargs)
    texts = tokenizer.batch_decode(generated)

    bubbles = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False) for text in texts]
    return ['', *bubbles]
|
|
|
|
|
|
|
# The app only runs inference: disable autograd globally so forward passes don't build graphs.
torch.set_grad_enabled(False)

# Key into `model_info` selecting which model this Space serves.
model_name = 'LLAMA2-7B'

# Split the selected config into app-level settings, which are popped here, and the remaining
# keyword arguments, which are forwarded unchanged to `from_pretrained`.
# deepcopy keeps these pops from mutating the shared `model_info` table.
model_args = deepcopy(model_info[model_name])
model_path = model_args.pop('model_path')
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
# An optional separate tokenizer repo; defaults to the model repo.
tokenizer_path = model_args.pop('tokenizer', model_path)
# An optional flag selecting the ctransformers loader instead of transformers.
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

model = AutoModelClass.from_pretrained(model_path, **model_args)
# Load the tokenizer with the same credentials as the model. Fall back to the environment
# variable for configs without an explicit `token`; `.get` avoids a KeyError when it is unset.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                          token=model_args.get('token', os.environ.get('hf_token')))
|
|
|
|
|
# NOTE(review): this JSON component is created before (outside) the `gr.Blocks` context below and
# is not referenced anywhere else in the visible code; it looks unused — confirm before removing.
json_output = gr.JSON()
|
# Custom CSS passed to `gr.Blocks(css=...)`.
# - `.bubble` / `.even_bubble` / `.odd_bubble` style the per-layer interpretation Textboxes;
#   consecutive layers alternate colours.
# - `.bubble textarea` makes the inner textarea blend into its bubble.
# - `.explanation_accordion ...` bolds the explanation accordion headers.
# The missing semicolon after `border: none` previously made browsers discard both that
# declaration and the following `border-radius`.
# NOTE(review): `.svelte-s1r2yt` looks like a build-generated Svelte scoping class; it may change
# between Gradio versions, in which case that rule silently stops matching.
css = '''
.bubble {
    border: none;
    border-radius: 10px;
    padding: 10px;
    margin-top: 15px;
    margin-left: 5%;
    width: 70%;
    box-shadow: 2px 2px 4px rgba(0, 0, 0, 0.3);
}

.even_bubble{
    background: pink;
}

.odd_bubble{
    background: skyblue;
}

.bubble textarea {
    border: none;
    box-shadow: none;
    background: inherit;
    resize: none;
}

.explanation_accordion .svelte-s1r2yt{
    font-weight: bold;
    text-align: start;
}
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Created outside the Blocks context so the Examples below can reference it before it is
# placed in the layout with `.render()`.
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')

with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    # Holds the stacked hidden states of the last computed prompt (written by get_hidden_states,
    # read by run_interpretation).
    global_state = gr.State([])

    # NOTE(review): the 'π…' sequences in the strings below look like UTF-8 emoji mis-decoded as
    # ISO-8859-7 (e.g. 'π₯³' matches the bytes of 🥳); restore the intended emoji from the
    # original source.
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# π Self-Interpreting Models')
            with gr.Accordion(
                label='πΎ This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!πΎ',
                elem_classes=['explanation_accordion']
            ):
                gr.Markdown(
                    '''This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
                    An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work π₯³) which was less mature but had the same idea in mind.
                    We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
                    ''', line_breaks=True)

            with gr.Accordion(label='πΎ The idea is really simple: models are able to understand their own hidden states by nature! πΎ',
                              elem_classes=['explanation_accordion']):
                gr.Markdown(
                    '''According to the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
                    So we can inject a representation from (roughly) any layer to any layer! If I give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
                    we expect to get back a summary of the information that exists inside the hidden state. Since the model uses a roughly common latent space, it can understand representations from different layers and different runs!! How cool is that! π―π―π―
                    ''', line_breaks=True)

        with gr.Column(scale=1):
            gr.Markdown('<span style="font-size:180px;">π€</span>')

    # gr.Group takes no title/label argument; the former positional 'Interpretation' string was
    # not a label and has been dropped.
    with gr.Group():
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')

    gr.Markdown('''
    Here are some example prompts whose internal representations we can analyze:
    ''')

    # One tab of sampled example prompts per configured dataset. Rows are streamed, optionally
    # filtered, then shuffled within a 2000-row buffer before the first `num_examples` are taken.
    num_examples = 10
    for info in dataset_info:
        with gr.Tab(info['name']):
            dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
            if 'filter' in info:
                dataset = dataset.filter(info['filter'])
            dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
            dataset = [[row[info['text_col']]] for row in dataset]
            gr.Examples(dataset, [original_prompt_raw])

    with gr.Group():
        original_prompt_raw.render()
        original_prompt_btn = gr.Button('Compute', variant='primary')

    # A fixed pool of MAX_PROMPT_TOKENS hidden buttons; get_hidden_states relabels and reveals
    # one per prompt token.
    with gr.Row():
        tokens_container = [gr.Button('', visible=False, elem_classes=['token_btn'])
                            for _ in range(MAX_PROMPT_TOKENS)]
    use_gpu = gr.Checkbox(value=True, label='Use GPU')
    progress_dummy = gr.Markdown('', elem_id='progress_dummy')

    # One output bubble per layer, alternating colours.
    # NOTE(review): a model's `hidden_states` output presumably also includes the embedding
    # output (num_hidden_layers + 1 entries). If so, run_interpretation may return one more
    # text than there are bubbles, and the last one is dropped — confirm.
    interpretation_bubbles = [
        gr.Textbox('', container=False, visible=False,
                   elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'])
        for i in range(model.config.num_hidden_layers)
    ]

    with gr.Accordion(open=False, label='Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    # `partial` binds each button's own token index now, avoiding late binding of the loop variable.
    # The input order must match run_interpretation's positional parameters.
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i),
                  [global_state, interpretation_prompt,
                   num_tokens, do_sample, temperature,
                   top_k, top_p, repetition_penalty, length_penalty,
                   use_gpu],
                  [progress_dummy, *interpretation_bubbles])

    original_prompt_btn.click(get_hidden_states,
                              [original_prompt_raw],
                              [progress_dummy, global_state, *tokens_container])

demo.launch()