File size: 13,542 Bytes
4d6d2dc
9fa8328
 
 
4d6d2dc
9b5c8c6
 
4d6d2dc
 
2533b7f
4d6d2dc
9fa8328
077e2b3
4d6d2dc
98858d4
4d6d2dc
bc2f9ff
4d6d2dc
6a634a2
 
63981db
fe6d32c
fb3b2a1
04aee9f
fe6d32c
b30e55a
 
4d6d2dc
f2d60cb
ce4a0c3
4d6d2dc
 
 
f2d60cb
ce4a0c3
4d6d2dc
 
 
f2d60cb
ce4a0c3
4d6d2dc
 
 
f2d60cb
 
 
 
 
 
4d6d2dc
 
 
3a7cf2a
 
 
 
bf7c081
3a7cf2a
 
4d6d2dc
 
9fa8328
 
 
 
 
 
 
 
 
4d6d2dc
1a614f9
 
 
 
9fa8328
049eed9
9fa8328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1cea83
9fa8328
 
4d6d2dc
5daf90b
4d6d2dc
 
b30a06e
 
de099ae
cf4e80d
9fa8328
e1cea83
d8c5a8d
 
20c0832
e1cea83
bc2f9ff
de099ae
b30a06e
9fa8328
4d6d2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fa8328
 
4d6d2dc
 
bc2f9ff
9fa8328
4d6d2dc
de099ae
518abab
5ae57a2
518abab
4d6d2dc
 
 
 
049eed9
03ac798
3e36699
049eed9
7f2e668
cf4e80d
049eed9
d028e6b
 
 
b30e55a
e036b8e
9fa8328
4d6d2dc
 
0d6b098
5ae57a2
 
 
b29377d
 
 
 
9f98ca2
b29377d
 
 
9f98ca2
 
 
 
868605b
 
 
9f98ca2
868605b
 
cf4e80d
 
 
9f98ca2
6b4003c
868605b
 
9fa8328
 
 
 
 
643b640
 
87f00d7
b3fa350
9fa8328
 
b3fa350
 
 
 
 
 
 
bfae66e
9fa8328
b3fa350
4d6d2dc
b30e55a
bc2f9ff
5e8b4c1
fd25f6d
cf4e80d
5e8b4c1
cf4e80d
 
9fa8328
cf4e80d
3a7cf2a
4d6d2dc
 
 
 
 
 
 
 
 
 
 
9fa8328
cf4e80d
9fa8328
 
 
643b640
e1cea83
dafad0d
e1cea83
9fa8328
5e8b4c1
bc2f9ff
5e8b4c1
63981db
5e8b4c1
 
765296c
af61663
e1cea83
0023648
4d6d2dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import os
import gc
from typing import Optional
from dataclasses import dataclass
from copy import deepcopy
from functools import partial
import spaces
import gradio as gr
import torch
from datasets import load_dataset
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
from interpret import InterpretationPrompt

# Number of token-button slots rendered in the UI.
MAX_PROMPT_TOKENS = 60


## info
# Example-prompt sources. Each entry gives the tab name, the HF dataset repo,
# the column holding the prompt text and, optionally, a row predicate.
dataset_info = [
    {
        'name': 'Commonsense',
        'hf_repo': 'tau/commonsense_qa',
        'text_col': 'question',
    },
    {
        'name': 'Factual Recall',
        'hf_repo': 'azhx/counterfact-filtered-gptj6b',
        'text_col': 'subject+predicate',
        'filter': lambda x: x['label'] == 1,
    },
    # {'name': 'Physical Understanding', 'hf_repo': 'piqa', 'text_col': 'goal'},
    {
        'name': 'Social Reasoning',
        'hf_repo': 'ProlificAI/social-reasoning-rlhf',
        'text_col': 'question',
    },
]


# Selectable models. reset_model pops model_path, the two prompt templates and the
# optional tokenizer / ctransformers keys; every remaining key is passed to
# from_pretrained as a keyword argument.
model_info = {
    'LLAMA2-7B': {
        'model_path': 'meta-llama/Llama-2-7b-chat-hf',
        'device_map': 'cpu',
        'token': os.environ['hf_token'],
        'original_prompt_template': '<s>{prompt}',
        'interpretation_prompt_template': '<s>[INST] [X] [/INST] {prompt}',
        # 'load_in_8bit': True,
    },

    'Gemma-2B': {
        'model_path': 'google/gemma-2b',
        'device_map': 'cpu',
        'token': os.environ['hf_token'],
        'original_prompt_template': '<bos>{prompt}',
        'interpretation_prompt_template': '<bos>User: [X]\n\nAnswer: {prompt}',
    },

    'Mistral-7B Instruct': {
        'model_path': 'mistralai/Mistral-7B-Instruct-v0.2',
        'device_map': 'cpu',
        'original_prompt_template': '<s>{prompt}',
        'interpretation_prompt_template': '<s>[INST] [X] [/INST] {prompt}',
    },

    # Disabled GGUF variant loaded through ctransformers:
    # 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
    #                                                tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
    #                                                model_type='llama', hf=True, ctransformers=True,
    #                                                original_prompt_template='<s>[INST] {prompt} [/INST]',
    #                                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    #                                               )
}


# Interpretation prompts offered as clickable examples; the first is the default.
suggested_interpretation_prompts = [
    "Sure, here's a bullet list of the key words in your message:",
    "Sure, I'll summarize your message:",
    "Sure, here are the words in your message:",
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:",
]


@dataclass
class GlobalState:
    """Mutable app-wide state shared by the Gradio event handlers.

    ``reset_model`` fills in the model, tokenizer and prompt templates;
    ``get_hidden_states`` stores the hidden states of the last analysed prompt.
    """

    # Tokenizer of the loaded model (None until a model is loaded).
    tokenizer: Optional[PreTrainedTokenizer] = None
    # Loaded causal language model (None until a model is loaded).
    model: Optional[PreTrainedModel] = None
    # Stacked hidden states of the most recent prompt (None before the first run).
    hidden_states: Optional[torch.Tensor] = None
    # Format string wrapping the interpretation prompt; '{prompt}' receives the user text.
    interpretation_prompt_template: str = '{prompt}'
    # Format string wrapping the prompt being analysed.
    original_prompt_template: str = '{prompt}'

    
## functions
@spaces.GPU
def initialize_gpu():
    """Do nothing; exists only to carry the ``spaces.GPU`` decorator.

    NOTE(review): not called anywhere in the visible code; presumably meant to
    trigger a GPU allocation on Spaces. Confirm before removing.
    """


def reset_model(model_name):
    """Load the model described by ``model_info[model_name]`` into ``global_state``.

    The entry's two prompt templates are copied onto ``global_state``. The
    ``model_path``, the optional ``tokenizer`` path (defaulting to the model
    path) and the optional ``ctransformers`` flag are removed from a private
    copy of the entry. All remaining keys are forwarded to ``from_pretrained``.
    The previous model, tokenizer and cached hidden states are dropped before
    the new model is loaded and moved to CUDA.

    Raises:
        KeyError: if ``model_name`` is not in ``model_info`` or the
            ``hf_token`` environment variable is unset.
    """
    # Deep-copy so popping keys never mutates the shared configuration.
    config = deepcopy(model_info[model_name])
    path = config.pop('model_path')
    global_state.original_prompt_template = config.pop('original_prompt_template')
    global_state.interpretation_prompt_template = config.pop('interpretation_prompt_template')
    tok_path = config.pop('tokenizer', path)
    if config.pop('ctransformers', False):
        loader = CAutoModelForCausalLM
    else:
        loader = AutoModelForCausalLM

    # Drop references to the old objects first so their memory can be
    # reclaimed before the new weights are allocated.
    global_state.model = None
    global_state.tokenizer = None
    global_state.hidden_states = None
    gc.collect()

    global_state.model = loader.from_pretrained(path, **config).cuda()
    global_state.tokenizer = AutoTokenizer.from_pretrained(tok_path, token=os.environ['hf_token'])
    gc.collect()
    

def get_hidden_states(raw_original_prompt):
    """Run the loaded model on a prompt and cache its hidden states.

    The prompt is wrapped in ``global_state.original_prompt_template`` and
    tokenized with ``add_special_tokens=False`` (the template supplies them).
    Every entry of the forward pass's ``hidden_states`` is squeezed along the
    batch dimension, moved to CPU and stacked. The result is stored in
    ``global_state.hidden_states``.

    Returns:
        A list matching the outputs
        ``[progress_dummy, *tokens_container, *interpretation_bubbles]``:
        - an empty string for the progress placeholder;
        - exactly ``MAX_PROMPT_TOKENS`` buttons: one visible button per prompt
          token (the first ``MAX_PROMPT_TOKENS`` only), the rest hidden;
        - one hidden textbox per interpretation bubble.

    NOTE(review): unlike ``run_interpretation`` this handler is not decorated
    with ``@spaces.GPU`` although it runs the model on ``model.device``.
    Confirm it works on the target hardware.
    """
    model, tokenizer = global_state.model, global_state.tokenizer
    original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])

    # Grad mode is thread-local in PyTorch. A module-level
    # torch.set_grad_enabled(False) does not reach the worker thread Gradio runs
    # this handler on, so disable autograd explicitly for the forward pass.
    with torch.no_grad():
        outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    # squeeze(0) drops the batch dimension (a single prompt is tokenized above),
    # so the stacked tensor is indexed [hidden_state_index, token, ...].
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)

    # The outputs provide exactly MAX_PROMPT_TOKENS button slots. Returning more
    # updates than outputs breaks the event, so longer prompts only get buttons
    # for their first MAX_PROMPT_TOKENS tokens.
    shown_tokens = tokens[:MAX_PROMPT_TOKENS]
    token_btns = ([gr.Button(token, visible=True) for token in shown_tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(shown_tokens))])
    progress_dummy_output = ''
    # Hide interpretations that belonged to the previous prompt.
    invisible_bubbles = [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
    global_state.hidden_states = hidden_states
    return [progress_dummy_output, *token_btns, *invisible_bubbles]


@spaces.GPU
def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                       num_beams=1):
    """Interpret the cached hidden states of prompt token ``i``.

    Args:
        raw_interpretation_prompt: user text inserted into
            ``global_state.interpretation_prompt_template``.
        max_new_tokens, do_sample, temperature, top_k, top_p,
        repetition_penalty, num_beams: generation settings forwarded to
            ``InterpretationPrompt.generate``. ``max_new_tokens``, ``top_k``
            and ``num_beams`` are cast to ``int``.
        length_penalty: UI value; negated before being forwarded.
        i: index of the prompt token whose hidden states are interpreted.

    Returns:
        ``['', *textboxes]`` with exactly one textbox per entry of the global
        ``interpretation_bubbles``. Each generated text (newlines flattened)
        becomes a visible textbox labelled with its layer index. Remaining
        slots are hidden, and texts beyond the number of bubbles are dropped.
    """
    # All stacked hidden states of token i (global_state.hidden_states is
    # indexed [hidden_state_index, token, ...]).
    interpreted_vectors = global_state.hidden_states[:, i]
    length_penalty = -length_penalty   # unintuitively, length_penalty > 0 will make sequences longer, so we negate it

    # generation parameters
    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }

    # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
    interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)

    # Grad mode is thread-local, so the module-level torch.set_grad_enabled(False)
    # does not apply in the Gradio worker thread running this handler.
    # k=3 is passed through unchanged; its meaning is defined by InterpretationPrompt.generate.
    with torch.no_grad():
        generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
    # Use the tokenizer of the loaded model; there is no module-level `tokenizer`.
    generation_texts = global_state.tokenizer.batch_decode(generated)

    # Gradio expects exactly one update per output component.
    # Assumes generate() returns one sequence per interpreted vector, in layer
    # order. TODO confirm against InterpretationPrompt.generate.
    num_bubbles = len(interpretation_bubbles)
    bubbles = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {layer}')
               for layer, text in enumerate(generation_texts[:num_bubbles])]
    bubbles += [gr.Textbox('', visible=False) for _ in range(num_bubbles - len(bubbles))]
    progress_dummy_output = ''
    return [progress_dummy_output, *bubbles]


## main
# Grad mode is thread-local in PyTorch. This disables autograd for the main
# thread, which performs the import-time model loading below.
torch.set_grad_enabled(False)
global_state = GlobalState()

# Load the default model at import time; the UI below sizes itself from its config.
model_name = 'LLAMA2-7B'
reset_model(model_name)

# Components created up front and rendered inside the layout below.
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
tokens_container = [gr.Button('', visible=False, elem_classes=['token_btn']) for _ in range(MAX_PROMPT_TOKENS)]


def _hide_token_buttons():
    """Return one hidden-button update per token slot."""
    return [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)]


with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:

    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# 😎 Self-Interpreting Models')

            gr.Markdown('<b style="color: #8B0000;">Model outputs are not filtered and might include undesired language!</b>')

            gr.Markdown(
            '''
                **👾 This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!👾**
                This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
                An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work 🥳) which was less mature but had the same idea in mind.
                We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
            ''', line_breaks=True)

            gr.Markdown(
            '''
            **👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**
            In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
            So we can inject a representation from (roughly) any layer into any layer! If we give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
            we expect to get back a summary of the information that exists inside the hidden state, despite being from a different layer and a different run!! How cool is that! 😯😯😯
            ''', line_breaks=True)

        with gr.Group():
            model_chooser = gr.Radio(choices=list(model_info.keys()), value=model_name)

    gr.Markdown('## Choose Your Interpretation Prompt')
    # gr.Group has no title parameter; the section heading is the Markdown above.
    with gr.Group():
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
        gr.Examples([[p] for p in suggested_interpretation_prompts], [interpretation_prompt], cache_examples=False)

    gr.Markdown('## The Prompt to Analyze')
    for info in dataset_info:
        with gr.Tab(info['name']):
            num_examples = 10
            dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
            if 'filter' in info:
                dataset = dataset.filter(info['filter'])
            dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
            dataset = [[row[info['text_col']]] for row in dataset]
            # Each example row has a single column, so the inputs must be exactly
            # one component: the prompt textbox (global_state is not a component).
            gr.Examples(dataset, [original_prompt_raw], cache_examples=False)

    with gr.Group():
        original_prompt_raw.render()
        original_prompt_btn = gr.Button('Output Token List', variant='primary')

    gr.Markdown('### Here go the tokens of the prompt (click on the one to explore)')

    with gr.Row():
        for btn in tokens_container:
            btn.render()

    with gr.Accordion(open=False, label='Generation Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
            # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    progress_dummy = gr.Markdown('', elem_id='progress_dummy')
    # One bubble per stacked hidden state. With output_hidden_states=True,
    # transformers returns num_hidden_layers + 1 tensors (embedding output plus
    # one per layer), and get_hidden_states stacks all of them.
    # NOTE(review): sized from the model loaded at startup; switching models
    # does not resize this list.
    num_bubbles = global_state.model.config.num_hidden_layers + 1
    interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
                                         elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
                                        ) for i in range(num_bubbles)]

    # event listeners
    # reset_model clears the cached hidden states, so hide the now-stale token
    # buttons instead of letting them index into None.
    model_chooser.change(reset_model, [model_chooser], []).then(_hide_token_buttons, [], tokens_container)

    for i, btn in enumerate(tokens_container):
        # Inputs map positionally onto run_interpretation's parameters
        # (raw_interpretation_prompt ... length_penalty); the token index is bound via partial.
        btn.click(partial(run_interpretation, i=i),
                  [interpretation_prompt, num_tokens, do_sample, temperature,
                   top_k, top_p, repetition_penalty, length_penalty],
                  [progress_dummy, *interpretation_bubbles])

    original_prompt_btn.click(get_hidden_states,
                              [original_prompt_raw],
                              [progress_dummy, *tokens_container, *interpretation_bubbles])
    original_prompt_raw.change(_hide_token_buttons, [], tokens_container)

demo.launch()