import os
import gc
from typing import Optional
from dataclasses import dataclass, field
from copy import deepcopy
from functools import partial

import numpy as np
import spaces
import gradio as gr
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM

from interpret import InterpretationPrompt
from configs import model_info, dataset_info

MAX_PROMPT_TOKENS = 60
MAX_NUM_LAYERS = 50

welcome_message = '**You are now running {model_name}!!** 🥳🥳🥳'

@dataclass
class LocalState:
    hidden_states: Optional[torch.Tensor] = None


@dataclass
class GlobalState:
    tokenizer: Optional[PreTrainedTokenizer] = None
    model: Optional[PreTrainedModel] = None
    sentence_transformer: Optional[SentenceTransformer] = None
    local_state: LocalState = field(default_factory=LocalState)
    wait_with_hidden_states: bool = False
    interpretation_prompt_template: str = '{prompt}'
    original_prompt_template: str = 'User: [X]\n\nAssistant: {prompt}'
    layers_format: str = 'model.layers.{k}'

suggested_interpretation_prompts = [
    "Sure, I'll summarize your message:",
    "The meaning of [X] is",
    "Sure, here's a bullet list of the key words in your message:",
    "Sure, here are the words in your message:",
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:"
]

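# A no-op decorated with @spaces.GPU. On ZeroGPU Spaces a decorated function is presumably
# declared so that GPU access is registered for the app (assumption; it is never called in this file).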
@spaces.GPU
def initialize_gpu():
    pass

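# Load the model named in configs.model_info into global_state: store its prompt templates and
# layer-name format, swap in the new weights and tokenizer, and (when with_extra_components is set)
# return fresh hidden UI components so stale interpretations disappear after a model change.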
def reset_model(model_name, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
    model_args = deepcopy(model_info[model_name])
    model_path = model_args.pop('model_path')
    global_state.original_prompt_template = model_args.pop('original_prompt_template')
    global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
    global_state.layers_format = model_args.pop('layers_format')
    tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
    use_ctransformers = model_args.pop('ctransformers', False)
    dont_cuda = model_args.pop('dont_cuda', False)
    global_state.wait_with_hidden_states = model_args.pop('wait_with_hidden_states', False)
    AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

    global_state.model, global_state.tokenizer, global_state.local_state.hidden_states = None, None, None
    gc.collect()
    global_state.model = AutoModelClass.from_pretrained(model_path, **model_args)
    if reset_sentence_transformer:
        global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    gc.collect()
    if not dont_cuda:
        global_state.model.to('cuda')
    global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
    gc.collect()
    if with_extra_components:
        return ([welcome_message.format(model_name=model_name)]
                + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
                + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
                + [*extra_components])

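# Format the user's prompt with the model's chat template and cache its hidden states.
# When wait_with_hidden_states is set, the forward pass is skipped here and deferred to
# run_interpretation (which runs under @spaces.GPU); otherwise the states are computed now.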
def get_hidden_states(raw_original_prompt, force_hidden_states=False):
    model, tokenizer = global_state.model, global_state.tokenizer
    original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    if global_state.wait_with_hidden_states and not force_hidden_states:
        global_state.local_state.hidden_states = None
    else:
        outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
        hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
        global_state.local_state.hidden_states = hidden_states.cpu().detach()

    token_btns = ([gr.Button(token, visible=True) for token in tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
    progress_dummy_output = ''
    invisible_bubbles = [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS)]
    return [progress_dummy_output, *token_btns, *invisible_bubbles]

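# Interpret one token of the original prompt: for token index i, the hidden state from every
# layer is patched into the [X] position of the interpretation prompt, and the model generates
# one "self-interpretation" per layer. Conceptually the patching amounts to a forward hook of
# roughly this shape (a sketch only; the real substitution is done by InterpretationPrompt.generate
# in the local `interpret` module):
#
#     def patch_hidden_state(module, inputs, output, vector=None, position=None):
#         output[0][:, position] = vector   # overwrite the residual-stream state at [X]
#         return output
#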
@spaces.GPU
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                       num_beams=1):
    model = global_state.model
    tokenizer = global_state.tokenizer
    print(f'run {model}')
    if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
        get_hidden_states(raw_original_prompt, force_hidden_states=True)
    interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
    hidden_means = torch.tensor(global_state.local_state.hidden_states.mean(dim=1)).to(model.device).to(model.dtype)
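    # NOTE: hidden_means is computed above but is not used further in this function.
    # The UI slider exposes 0-5 as 'Length Penalty', whereas transformers' length_penalty rewards
    # longer beams when positive, so the sign is flipped here (it only takes effect with beam search).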
    length_penalty = -length_penalty

    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }

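    # Build the interpretation prompt; InterpretationPrompt (from the local `interpret` module)
    # presumably records where the placeholder tokens sit in the template so the interpreted
    # vectors can be substituted at those positions during generation.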
    interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)

    generated = interpretation_prompt.generate(model, {0: interpreted_vectors},
                                               layers_format=global_state.layers_format, k=3,
                                               **generation_kwargs)
    generation_texts = tokenizer.batch_decode(generated)

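    # Heuristic for highlighting "interesting" layers: normalize the patched vectors, measure how
    # much the representation moves between consecutive layers, and mark the ~10% of layers with
    # the largest jumps (the first two and the last layer are excluded from the comparison).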
    vectors_to_compare = interpreted_vectors
    avoid_first, avoid_last = 2, 1
    vectors_to_compare = vectors_to_compare[avoid_first:-avoid_last]
    diff_score = F.normalize(vectors_to_compare, dim=-1).diff(dim=0).norm(dim=-1)
    important_idxs = avoid_first + diff_score.topk(k=int(np.ceil(0.1 * len(generation_texts)))).indices.cpu().numpy()

    print(f'{important_idxs=}')
    progress_dummy_output = ''
    elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
                    ([] if i in important_idxs else ['faded_bubble']) for i in range(len(generation_texts))]
    bubble_outputs = [gr.Textbox(text.replace('\n', ' '), show_label=True, visible=True,
                                 container=True, label=f'Layer {i}', elem_classes=elem_classes[i])
                      for i, text in enumerate(generation_texts)]
    bubble_outputs += [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
    return [progress_dummy_output, *bubble_outputs]

torch.set_grad_enabled(False)
global_state = GlobalState()

model_name = 'LLAMA2-7B'
reset_model(model_name, with_extra_components=False, reset_sentence_transformer=True)

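# These components are created before the gr.Blocks layout and rendered inside it later with
# .render(); gr.Examples in the dataset tabs references raw_original_prompt before that point,
# so it has to exist up front.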
raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
tokens_container = []

for _ in range(MAX_PROMPT_TOKENS):
    btn = gr.Button('', visible=False, elem_classes=['token_btn'])
    tokens_container.append(btn)

with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# Self-Interpreting Models')

            gr.Markdown('<b style="color: #8B0000;">Model outputs are not filtered and might include undesired language!</b>')

            gr.Markdown(
                '''
**This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free-form natural language!!**
This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
Honorary mention: **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work 🥳). It was less mature but had the same idea in mind. I think it can be a great introduction to the subject!
We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
                ''', line_breaks=True)

            gr.Markdown(
                '''
**The idea is really simple: models are able to understand their own hidden states by nature!**
In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
So we can inject a representation from (roughly) any layer into any layer! If we give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
we expect to get back a summary of the information that exists inside that hidden state, even though it comes from a different layer and a different run!! How cool is that! 💯💯💯
                ''', line_breaks=True)

    with gr.Group():
        model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
        welcome_model = gr.Markdown(welcome_message.format(model_name=model_name))

    with gr.Blocks() as demo_main:
        gr.Markdown('## The Prompt to Analyze')
        for info in dataset_info:
            with gr.Tab(info['name']):
                num_examples = 10
                dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
                if 'filter' in info:
                    dataset = dataset.filter(info['filter'])
                dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
                dataset = [[row[info['text_col']]] for row in dataset]
                gr.Examples(dataset, [raw_original_prompt], cache_examples=False)

        with gr.Group():
            raw_original_prompt.render()
            original_prompt_btn = gr.Button('Output Token List', variant='primary')

        gr.Markdown('## Choose Your Interpretation Prompt')
        with gr.Group('Interpretation'):
            raw_interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
            interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
                                                         [raw_interpretation_prompt], cache_examples=False)

        with gr.Accordion(open=False, label='Generation Settings'):
            with gr.Row():
                num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
                repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
                length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')

            do_sample = gr.Checkbox(label='With sampling')
            with gr.Accordion(label='Sampling Parameters'):
                with gr.Row():
                    temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                    top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                    top_p = gr.Slider(0., 1., value=0.95, label='top p')

        gr.Markdown('''
## Tokens
### Here are the tokens of your prompt (click one to explore it)
''')
        with gr.Row():
            for btn in tokens_container:
                btn.render()

        progress_dummy = gr.Markdown('', elem_id='progress_dummy')
        interpretation_bubbles = [gr.Textbox('', container=False, visible=False)
                                  for _ in range(MAX_NUM_LAYERS)]

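        # Wire each token button to run_interpretation with its own token index baked in via
        # functools.partial, so clicking a token interprets that position across all layers.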
        for i, btn in enumerate(tokens_container):
            btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
                                                         num_tokens, do_sample, temperature,
                                                         top_k, top_p, repetition_penalty, length_penalty
                                                         ], [progress_dummy, *interpretation_bubbles])

        original_prompt_btn.click(get_hidden_states,
                                  [raw_original_prompt],
                                  [progress_dummy, *tokens_container, *interpretation_bubbles])
        raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)

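        # Switching models reloads the weights and hides any stale interpretation bubbles and
        # token buttons; the extra components are passed through reset_model unchanged.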
        extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
        model_chooser.change(reset_model, [model_chooser, *extra_components],
                             [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])

demo.launch()