import gc
import os
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial
from typing import Optional

import spaces
import gradio as gr
import torch
from datasets import load_dataset
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer

from interpret import InterpretationPrompt
from configs import model_info, dataset_info
# Number of pre-allocated token buttons in the UI, i.e. the most prompt tokens
# that can be displayed/clicked at once.
MAX_PROMPT_TOKENS: int = 60
# Number of pre-allocated interpretation bubbles (one per hidden-state layer).
MAX_NUM_LAYERS: int = 50
# Markdown greeting shown for the active model; fill in with .format(model_name=...).
welcome_message: str = '**You are now running {model_name}!!** 🥳🥳🥳'
@dataclass
class LocalState:
    """Cache for the prompt currently being analyzed.

    ``hidden_states`` is the stacked hidden-state tensor produced by
    ``get_hidden_states`` for the last prompt, or ``None`` when it has not
    been computed (or was cleared).
    """

    hidden_states: Optional[torch.Tensor] = None
@dataclass
class GlobalState:
    """Process-wide state shared by every Gradio event handler.

    Holds the currently loaded model and tokenizer, the cached hidden states
    of the prompt being analyzed, and the per-model settings that
    ``reset_model`` overwrites from ``model_info``.
    """

    tokenizer: Optional[PreTrainedTokenizer] = None
    model: Optional[PreTrainedModel] = None
    # default_factory gives each GlobalState its own LocalState. A shared
    # instance as the default is unhashable and is rejected by dataclasses on
    # Python 3.11+ (and would be shared between instances on older versions).
    local_state: LocalState = field(default_factory=LocalState)
    # When True, the forward pass for hidden states is postponed until a token
    # is clicked. Named to match the attribute the handlers read and write.
    wait_with_hidden_states: bool = False
    # Template around the user's interpretation prompt; formatted with
    # ``prompt`` (and ``repeat``) before building the InterpretationPrompt.
    interpretation_prompt_template: str = '{prompt}'
    # Template around the prompt to analyze; formatted with ``prompt``.
    original_prompt_template: str = 'User: [X]\n\nAnswer: {prompt}'
    # Format string passed as ``layers_format`` to InterpretationPrompt.generate;
    # presumably the module path of layer ``k`` — confirm against interpret.py.
    layers_format: str = 'model.layers.{k}'
# Interpretation prompts offered as clickable examples; the first entry is
# also the default value of the interpretation-prompt textbox.
suggested_interpretation_prompts = [
    "Sure, here's a bullet list of the key words in your message:",
    "Sure, I'll summarize your message:",
    "Sure, here are the words in your message:",
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:",
]
## functions
@spaces.GPU
def initialize_gpu():
    """Do nothing; exists only to carry the ``spaces.GPU`` decorator.

    NOTE(review): presumably lets the Spaces runtime register a GPU-decorated
    function at import time — confirm; it is not called anywhere in this file.
    """
def reset_model(model_name, *extra_components, with_extra_components=True):
    """Load ``model_name`` (a key of ``model_info``) into ``global_state``.

    Copies the model's config, moves the template/layer settings into
    ``global_state``, frees the previous model, loads the new model (moved to
    CUDA unless ``dont_cuda`` is set) and its tokenizer.

    Args:
        model_name: key into ``model_info``.
        *extra_components: current values of extra UI components; returned
            unchanged so the event can list them as outputs.
        with_extra_components: when False (initial load, before the UI
            exists) nothing is returned.

    Returns:
        When ``with_extra_components`` is True, the update list for
        ``[welcome_model, *interpretation_bubbles, *tokens_container,
        *extra_components]``; otherwise ``None``.
    """
    # extract model info (copied so popping keys does not mutate model_info)
    model_args = deepcopy(model_info[model_name])
    model_path = model_args.pop('model_path')
    global_state.original_prompt_template = model_args.pop('original_prompt_template')
    global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
    global_state.layers_format = model_args.pop('layers_format')
    tokenizer_path = model_args.pop('tokenizer', model_path)
    use_ctransformers = model_args.pop('ctransformers', False)
    dont_cuda = model_args.pop('dont_cuda', False)
    global_state.wait_with_hidden_states = model_args.pop('wait_with_hidden_states', False)
    AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

    # Drop every reference to the previous model (and its cached hidden
    # states) before loading the next one, then release the memory: gc frees
    # the Python objects, empty_cache returns cached CUDA blocks (no-op when
    # CUDA was never initialised).
    global_state.model, global_state.tokenizer, global_state.local_state.hidden_states = None, None, None
    gc.collect()
    torch.cuda.empty_cache()

    # Whatever keys remain in model_args are forwarded to from_pretrained.
    global_state.model = AutoModelClass.from_pretrained(model_path, **model_args)
    if not dont_cuda:
        global_state.model.to('cuda')
    # A missing 'hf_token' env var no longer raises KeyError; None lets the
    # hub client fall back to anonymous/cached credentials.
    global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ.get('hf_token'))
    gc.collect()

    if with_extra_components:
        return ([welcome_message.format(model_name=model_name)]
                + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
                + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
                + [*extra_components])
def get_hidden_states(raw_original_prompt, force_hidden_states=False):
    """Tokenize the prompt to analyze and (optionally) cache its hidden states.

    The raw prompt is wrapped in ``global_state.original_prompt_template`` and
    tokenized without special tokens. Unless the model defers the work
    (``global_state.wait_with_hidden_states``) and ``force_hidden_states`` is
    False, a forward pass is run with ``output_hidden_states=True`` and the
    per-layer hidden states are stacked into one CPU tensor stored in
    ``global_state.local_state.hidden_states``; otherwise that cache is cleared.

    Args:
        raw_original_prompt: user text to analyze.
        force_hidden_states: compute hidden states even in deferred mode.

    Returns:
        Updates for ``[progress_dummy, *tokens_container,
        *interpretation_bubbles]``: one visible button per token (at most
        ``MAX_PROMPT_TOKENS``), hidden buttons for the remaining slots, and
        every interpretation bubble hidden.
    """
    model, tokenizer = global_state.model, global_state.tokenizer
    original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])

    if global_state.wait_with_hidden_states and not force_hidden_states:
        # Deferred mode: run_interpretation computes them on the first click.
        global_state.local_state.hidden_states = None
    else:
        outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
        # assumes each entry is (1, seq_len, hidden_dim) (HF convention) — the
        # squeeze drops the batch axis, stack adds a leading layer axis.
        global_state.local_state.hidden_states = torch.stack(
            [h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)

    # Exactly MAX_PROMPT_TOKENS button components are wired as outputs; emit
    # exactly that many updates so the outputs stay aligned even for long
    # prompts (tokens past the cap are simply not shown).
    shown_tokens = tokens[:MAX_PROMPT_TOKENS]
    token_btns = ([gr.Button(token, visible=True) for token in shown_tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(shown_tokens))])
    progress_dummy_output = ''
    invisible_bubbles = [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS)]
    return [progress_dummy_output, *token_btns, *invisible_bubbles]
@spaces.GPU
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                       num_beams=1):
    """Let the model describe, in text, its own hidden states for token ``i``.

    Takes column ``i`` of the cached hidden states (token ``i`` at every
    layer), patches it into the interpretation prompt via
    ``InterpretationPrompt.generate`` and decodes the generations.

    Args:
        raw_original_prompt: prompt being analyzed; used to (re)compute the
            hidden states when none are cached.
        raw_interpretation_prompt: user prompt inserted into
            ``global_state.interpretation_prompt_template``.
        max_new_tokens, do_sample, temperature, top_k, top_p,
        repetition_penalty, num_beams: forwarded as generation kwargs.
        length_penalty: UI value; negated before being forwarded (see below).
        i: index of the clicked token in the tokenized original prompt.

    Returns:
        Updates for ``[progress_dummy, *interpretation_bubbles]``: one visible
        textbox per generated text (at most ``MAX_NUM_LAYERS``), the rest hidden.
    """
    model = global_state.model
    tokenizer = global_state.tokenizer
    print(f'run {model}')
    # Hidden states are missing in deferred mode (wait_with_hidden_states) and
    # whenever the cache was cleared; compute them instead of indexing None.
    if global_state.local_state.hidden_states is None:
        get_hidden_states(raw_original_prompt, force_hidden_states=True)
    # Token i at every layer. copy=True keeps the cached tensor untouched even
    # when no device/dtype conversion is needed (torch.tensor(tensor) warns).
    interpreted_vectors = global_state.local_state.hidden_states[:, i].to(
        device=model.device, dtype=model.dtype, copy=True)
    length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it

    # generation parameters
    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }

    # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
    interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)

    # generate the interpretations
    generated = interpretation_prompt.generate(model, {0: interpreted_vectors},
                                               layers_format=global_state.layers_format, k=3,
                                               **generation_kwargs)
    generation_texts = tokenizer.batch_decode(generated)

    progress_dummy_output = ''
    # Presumably one generation per row of interpreted_vectors, i.e. per
    # layer; label each bubble with that layer index (not the token index i),
    # and never emit more updates than there are bubble components.
    bubble_outputs = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False,
                                 label=f'Layer {layer_idx}')
                      for layer_idx, text in enumerate(generation_texts[:MAX_NUM_LAYERS])]
    bubble_outputs += [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
    return [progress_dummy_output, *bubble_outputs]
## main
# Nothing in this app needs autograd; disable it globally.
torch.set_grad_enabled(False)
global_state = GlobalState()
model_name = 'LLAMA2-7B'
# Initial load happens before any UI exists, so no component updates are requested.
reset_model(model_name, with_extra_components=False)
# Built outside the layout and placed later with .render().
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
tokens_container = [gr.Button('', visible=False, elem_classes=['token_btn'])
                    for _ in range(MAX_PROMPT_TOKENS)]
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
    # Header: title and background on self-interpretation.
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# 😎 Self-Interpreting Models')
            gr.Markdown('Model outputs are not filtered and might include undesired language!')
            gr.Markdown(
                '''
**👾 This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!👾**
This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
Honorary mention: **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work 🥳). It was less mature but had the same idea in mind. I think it can be a great introduction to the subject!
We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
''', line_breaks=True)
            gr.Markdown(
                '''
**👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**
In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
So we can inject a representation from (roughly) any layer into any layer! If we give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
we expect to get back a summary of the information that exists inside the hidden state, despite being from a different layer and a different run!! How cool is that! 😯😯😯
''', line_breaks=True)

    # Model selection.
    with gr.Group():
        model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
        welcome_model = gr.Markdown(welcome_message.format(model_name=model_name))

    with gr.Blocks() as demo_blocks:
        # Prompt to analyze, with a few streamed examples from each configured dataset.
        gr.Markdown('## The Prompt to Analyze')
        num_examples = 10
        for info in dataset_info:
            with gr.Tab(info['name']):
                dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
                if 'filter' in info:
                    dataset = dataset.filter(info['filter'])
                dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
                dataset = [[row[info['text_col']]] for row in dataset]
                gr.Examples(dataset, [original_prompt_raw], cache_examples=False)
        with gr.Group():
            original_prompt_raw.render()
            original_prompt_btn = gr.Button('Output Token List', variant='primary')

        # Interpretation prompt. gr.Group takes no title; the old positional
        # 'Interpretation' argument had no effect, so it is dropped.
        gr.Markdown('## Choose Your Interpretation Prompt')
        with gr.Group():
            interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
            gr.Examples([[p] for p in suggested_interpretation_prompts],
                        [interpretation_prompt], cache_examples=False)

        # Generation settings. NOTE: num_beams is not exposed in the UI, so
        # run_interpretation always uses its default (1).
        with gr.Accordion(open=False, label='Generation Settings'):
            with gr.Row():
                num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
                repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
                length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
            do_sample = gr.Checkbox(label='With sampling')
            with gr.Accordion(label='Sampling Parameters'):
                with gr.Row():
                    temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                    top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                    top_p = gr.Slider(0., 1., value=0.95, label='top p')

        # Token buttons (pre-allocated, shown by get_hidden_states) and one
        # interpretation bubble per layer (shown by run_interpretation).
        gr.Markdown('''
## Tokens
### Here go the tokens of the prompt (click on one to explore)
''')
        with gr.Row():
            for btn in tokens_container:
                btn.render()

        progress_dummy = gr.Markdown('', elem_id='progress_dummy')
        interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
                                             elem_classes=['bubble', 'even_bubble' if layer_idx % 2 == 0 else 'odd_bubble'])
                                  for layer_idx in range(MAX_NUM_LAYERS)]

        # event listeners
        # partial binds each button's own token index (avoids late binding of i).
        for i, btn in enumerate(tokens_container):
            btn.click(partial(run_interpretation, i=i), [original_prompt_raw, interpretation_prompt,
                                                         num_tokens, do_sample, temperature,
                                                         top_k, top_p, repetition_penalty, length_penalty
                                                         ], [progress_dummy, *interpretation_bubbles])

        original_prompt_btn.click(get_hidden_states,
                                  [original_prompt_raw],
                                  [progress_dummy, *tokens_container, *interpretation_bubbles])
        # Editing the prompt invalidates the displayed tokens.
        original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)

        extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]
        model_chooser.change(reset_model, [model_chooser, *extra_components],
                             [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])

demo.launch()