# Gradio Space: run a causal LM on a prompt, then let the user click a prompt token to have the
# model "interpret" the hidden states of that token via an interpretation prompt (SelfIE-style).
import os
from copy import deepcopy
from functools import partial

import spaces
import gradio as gr
import torch
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

from interpret import InterpretationPrompt

# Number of pre-allocated token buttons in the UI; get_hidden_states must return exactly this many.
MAX_PROMPT_TOKENS = 30


## info
# Per-model configuration. `model_path`, `original_prompt_template`, `interpretation_prompt_template`
# and the optional `tokenizer` / `ctransformers` keys are popped in the main section below; every
# remaining key is forwarded to `from_pretrained`.
# NOTE: os.environ['hf_token'] is read at import time, so the app fails with KeyError if it is unset.
model_info = {
    'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=os.environ['hf_token'],
                      original_prompt_template='[INST] {prompt} [/INST]',
                      interpretation_prompt_template='[INST] [X] [/INST] {prompt}',
                      ),  # , load_in_8bit=True

    'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
                     original_prompt_template=' {prompt}',
                     interpretation_prompt_template='User: [X]\n\nAnswer: {prompt}',
                     ),

    'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
                                original_prompt_template='[INST] {prompt} [/INST]',
                                interpretation_prompt_template='[INST] [X] [/INST] {prompt}',
                                ),

    # 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
    #                                                tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
    #                                                model_type='llama', hf=True, ctransformers=True,
    #                                                original_prompt_template='[INST] {prompt} [/INST]',
    #                                                interpretation_prompt_template='[INST] [X] [/INST] {prompt}',
    #                                                )
}

suggested_interpretation_prompts = ["Before responding, let me repeat the message you wrote:",
                                    "Let me repeat the message:", "Sure, I'll summarize your message:"]


## functions
@spaces.GPU
def initialize_gpu():
    # Intentionally empty and never called in this file.
    # NOTE(review): presumably present only so the Space registers a `@spaces.GPU` function — confirm.
    pass


def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
    """Run the model on the templated prompt and collect the hidden states of every token.

    Args:
        raw_original_prompt: user text, inserted into `original_prompt_template`.
        progress: gradio progress tracker (injected by gradio; not used directly here).

    Returns:
        ``[progress_dummy_output, hidden_states, *token_buttons]``, matching the outputs wired to
        the "Compute" button. ``hidden_states`` is ``torch.stack`` of every entry of
        ``outputs.hidden_states`` (leading dim squeezed, moved to CPU). Exactly
        ``MAX_PROMPT_TOKENS`` buttons follow: one visible button per displayed token, padded
        with hidden empty buttons.
    """
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    # decode each id on its own so that every button shows a single token
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    # presumably (num_hidden_states, seq_len, hidden_dim) since a single prompt is encoded — TODO confirm
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
    # Only MAX_PROMPT_TOKENS buttons exist in the UI: cap explicitly rather than returning more
    # values than there are output components. Hidden states are still kept for every token.
    shown_tokens = tokens[:MAX_PROMPT_TOKENS]
    token_btns = ([gr.Button(token, visible=True) for token in shown_tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(shown_tokens))])
    progress_dummy_output = ''
    return [progress_dummy_output, hidden_states, *token_btns]


def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                       num_beams=1):
    """Interpret the hidden states of prompt token ``i`` using the interpretation prompt.

    ``global_state`` is the tensor stored by ``get_hidden_states``; ``global_state[:, i]`` selects
    token ``i`` across all stacked hidden states. The vectors are handed to
    ``InterpretationPrompt.generate`` as ``{0: interpreted_vectors}`` with ``k=3``, and the decoded
    generations are returned as visible textboxes (newlines replaced by spaces), preceded by the
    empty progress-dummy value.

    NOTE(review): the UI creates ``model.config.num_hidden_layers`` bubbles, while ``global_state``
    holds one entry per element of ``outputs.hidden_states``. If ``generate`` returns one text per
    vector, the counts may differ and gradio would drop the extra value(s) — verify.
    """
    interpreted_vectors = global_state[:, i]
    # unintuitively, length_penalty > 0 will make sequences longer, so we negate it.
    # (Per the transformers docs it only affects beam-based generation; num_beams defaults to 1.)
    length_penalty = -length_penalty

    # generation parameters
    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }

    # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
    interpretation_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt)
    interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)

    # generate the interpretations
    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
    generation_texts = tokenizer.batch_decode(generated)
    progress_dummy_output = ''
    return ([progress_dummy_output]
            + [gr.Textbox(text.replace('\n', ' '), visible=True, container=False) for text in generation_texts])


## main
torch.set_grad_enabled(False)
model_name = 'LLAMA2-7B'

# extract model info
model_args = deepcopy(model_info[model_name])
model_path = model_args.pop('model_path')
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
tokenizer_path = model_args.pop('tokenizer', model_path)
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

# get model
model = AutoModelClass.from_pretrained(model_path, **model_args)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])

# demo
# NOTE(review): `.svelte-s1r2yt` looks like a Svelte-generated class hash and may change between
# gradio versions — confirm the accordion styling still applies after upgrades.
css = '''
.bubble {
    border: 2px solid #000;
    border-radius: 10px;
    padding: 10px;
    margin-left: 5%;
    width: 90%;
    background: pink;
}
.bubble textarea {
    border: none;
    box-shadow: none;
    background: inherit;
    resize: none;
}
.explanation_accordion .svelte-s1r2yt{
    font-weight: bold;
    text-align: start;
}
'''

# '''
# .token_btn{
#     background-color: none;
#     background: none;
#     border: none;
#     padding: 0;
#     font: inherit;
#     cursor: pointer;
#     color: blue; /* default text color */
#     font-weight: bold;
# }
# .token_btn:hover {
#     color: red;
# }
# '''


with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    # Starts as an empty list; get_hidden_states replaces it with the stacked hidden-state tensor.
    global_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# 😎 Self-Interpreting Models')
            with gr.Accordion(
                label='👾 This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!👾',
                elem_classes=['explanation_accordion']
            ):
                gr.Markdown(
'''This idea was explored in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was later investigated further in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my own work!! 🥳) which was less mature but had the same idea in mind.
We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
''', line_breaks=True)

            with gr.Accordion(label='👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾',
                              elem_classes=['explanation_accordion']):
                gr.Markdown(
'''If I give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace ``[X]`` *during computation* with the hidden state we want to understand,
we hope to get back a summary of the information that exists inside the hidden state, because it is encoded in a latent space the model uses itself!! How cool is that! 😯😯😯
''', line_breaks=True)

        with gr.Column(scale=1):
            gr.Markdown('🤔')

    # with gr.Group():
    #     with gr.Row():
    #         for txt in model_info.keys():
    #             btn = gr.Button(txt)
    #             model_btns.append(btn)
    # for btn in model_btns:
    #     btn.click(reset_new_model, [global_state])

    # placeholder tabs (no content yet)
    with gr.Blocks():
        with gr.Tab('Memory Recall'):
            pass
        with gr.Tab('Physics Understanding'):
            pass
        with gr.Tab('Common Sense'):
            pass
        with gr.Tab('LLM Attacks'):
            pass

    with gr.Group():
        original_prompt_raw = gr.Textbox(value='Should I eat cake or vegetables?', container=True, label='Original Prompt')
        original_prompt_btn = gr.Button('Compute', variant='primary')

    # pre-allocated, initially hidden token buttons; get_hidden_states fills and reveals them
    tokens_container = []
    with gr.Row():
        for _ in range(MAX_PROMPT_TOKENS):
            btn = gr.Button('', visible=False, elem_classes=['token_btn'])
            tokens_container.append(btn)

    # empty output component that both handlers write '' to (gives gradio a place to show progress)
    progress_dummy = gr.Markdown('', elem_id='progress_dummy')

    with gr.Group():  # interpretation prompt and generation settings
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
        with gr.Accordion(open=False, label='Settings'):
            with gr.Row():
                num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
                repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
                length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
                # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
            do_sample = gr.Checkbox(label='With sampling')
            with gr.Accordion(label='Sampling Parameters'):
                with gr.Row():
                    temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                    top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                    top_p = gr.Slider(0., 1., value=0.95, label='top p')

    with gr.Group():  # interpretation outputs, one hidden bubble per model layer
        interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
                                  for _ in range(model.config.num_hidden_layers)]

    # each token button interprets its own token position (bound via partial)
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i),
                  [global_state, interpretation_prompt, num_tokens, do_sample, temperature, top_k, top_p,
                   repetition_penalty, length_penalty],
                  [progress_dummy, *interpretation_bubbles])

    original_prompt_btn.click(get_hidden_states,
                              [original_prompt_raw],
                              [progress_dummy, global_state, *tokens_container])

demo.launch()