dar-tau committed
Commit 9fa8328 (1 parent: 0023648)

Update app.py

Files changed (1): app.py (+60, −48)
app.py CHANGED
@@ -1,4 +1,7 @@
 import os
+import gc
+from typing import Optional
+from dataclasses import dataclass
 from copy import deepcopy
 from functools import partial
 import spaces
@@ -6,7 +9,7 @@ import gradio as gr
 import torch
 from datasets import load_dataset
 from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
 from interpret import InterpretationPrompt
 
 MAX_PROMPT_TOKENS = 60
@@ -56,13 +59,43 @@ suggested_interpretation_prompts = [
 ]
 
 
+@dataclass
+class GlobalState:
+    tokenizer : Optional[PreTrainedTokenizer] = None
+    model : Optional[PreTrainedModel] = None
+    hidden_states : Optional[torch.Tensor] = None
+    interpretation_prompt_template : str = '{prompt}'
+    original_prompt_template : str = '{prompt}'
+
+
 ## functions
 @spaces.GPU
 def initialize_gpu():
     pass
 
-def get_hidden_states(raw_original_prompt):
-    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
+
+def reset_model(model_name, global_state):
+    # extract model info
+    model_args = deepcopy(model_info[model_name])
+    model_path = model_args.pop('model_path')
+    global_state.original_prompt_template = model_args.pop('original_prompt_template')
+    global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
+    tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
+    use_ctransformers = model_args.pop('ctransformers', False)
+    AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
+
+    # get model
+    global_state.model, global_state.tokenizer, global_state.hidden_states = None, None, None
+    gc.collect()
+    global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
+    global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
+    gc.collect()
+    return global_state
+
+
+def get_hidden_states(global_state, raw_original_prompt):
+    model, tokenizer = global_state.model, global_state.tokenizer
+    original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
@@ -71,7 +104,8 @@ def get_hidden_states(raw_original_prompt):
                   + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
     progress_dummy_output = ''
     invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
-    return [progress_dummy_output, hidden_states, *token_btns, *invisible_bubbles]
+    global_state.hidden_states = hidden_states
+    return [progress_dummy_output, global_state, *token_btns, *invisible_bubbles]
 
 
 @spaces.GPU
@@ -79,7 +113,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
 
-    interpreted_vectors = global_state[:, i]
+    interpreted_vectors = global_state.hidden_states[:, i]
     length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
     # generation parameters
@@ -95,12 +129,12 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     }
 
     # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
-    interpretation_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
-    interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
+    interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
+    interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)
 
     # generate the interpretations
     # generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
-    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
+    generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     progress_dummy_output = ''
     return ([progress_dummy_output] +
@@ -109,23 +143,9 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
 
 
 ## main
+
 torch.set_grad_enabled(False)
 model_name = 'LLAMA2-7B'
-
-# extract model info
-model_args = deepcopy(model_info[model_name])
-model_path = model_args.pop('model_path')
-original_prompt_template = model_args.pop('original_prompt_template')
-interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
-tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
-use_ctransformers = model_args.pop('ctransformers', False)
-AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
-
-# get model
-model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
-tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
-
-# demo
 original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
 tokens_container = []
 for i in range(MAX_PROMPT_TOKENS):
@@ -133,7 +153,8 @@ for i in range(MAX_PROMPT_TOKENS):
     tokens_container.append(btn)
 
 with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
-    global_state = gr.State([])
+    global_state = gr.State(reset_model(model_name, GlobalState()))
+
     with gr.Row():
         with gr.Column(scale=5):
             gr.Markdown('# 😎 Self-Interpreting Models')
@@ -165,20 +186,17 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     # with gr.Column(scale=1):
     #     gr.Markdown('<span style="font-size:180px;">🤔</span>')
-    gr.Markdown('''
-    ## Choose Your Interpretation Prompt
-    ''')
+
+    with gr.Group():
+        model_chooser = gr.Radio(choices=list(model_info.keys()), value=model_name)
+
+    gr.Markdown('## Choose Your Interpretation Prompt')
     with gr.Group('Interpretation'):
         interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
         gr.Examples([[p] for p in suggested_interpretation_prompts], [interpretation_prompt], cache_examples=False)
-        # gr.Markdown('''
-        # Here are some examples of prompts we can analyze their internal representations:
-        # ''')
-
 
-    gr.Markdown('''
-    ## The Prompt to Analyze
-    ''')
+
+    gr.Markdown('## The Prompt to Analyze')
     for info in dataset_info:
         with gr.Tab(info['name']):
             num_examples = 10
@@ -187,7 +205,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
             dataset = dataset.filter(info['filter'])
             dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
             dataset = [[row[info['text_col']]] for row in dataset]
-            gr.Examples(dataset, [original_prompt_raw], cache_examples=False)
+            gr.Examples(dataset, [global_state, original_prompt_raw], cache_examples=False)
 
     with gr.Group():
         original_prompt_raw.render()
@@ -198,6 +216,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
         with gr.Row():
             for btn in tokens_container:
                 btn.render()
+
 
     with gr.Accordion(open=False, label='Generation Settings'):
         with gr.Row():
@@ -211,22 +230,15 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
             temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
             top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
            top_p = gr.Slider(0., 1., value=0.95, label='top p')
-
+
     progress_dummy = gr.Markdown('', elem_id='progress_dummy')
-
-    interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble',
-                                                                                           'even_bubble' if i % 2 == 0 else 'odd_bubble'])
-                              for i in range(model.config.num_hidden_layers)]
-
-    # with gr.Group():
-    #     with gr.Row():
-    #         for txt in model_info.keys():
-    #             btn = gr.Button(txt)
-    #             model_btns.append(btn)
-    #     for btn in model_btns:
-    #         btn.click(reset_new_model, [global_state])
+    interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
+                                         elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
+                                         ) for i in range(model.config.num_hidden_layers)]
 
     # event listeners
+    model_chooser.change(reset_new_model, [model_chooser, global_state], [global_state])
+
     for i, btn in enumerate(tokens_container):
         btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
                                                      num_tokens, do_sample, temperature,
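
The core of this commit is moving the model, tokenizer, prompt templates, and cached hidden states out of module-level globals and into a per-session GlobalState held in gr.State, with a gr.Radio whose change event reloads the model. A minimal, self-contained sketch of that pattern follows; the model_info contents and model path are illustrative placeholders, not values from the repo:

# Minimal sketch of the per-session state pattern introduced above.
# The model path in model_info is an illustrative placeholder.
import gc
from dataclasses import dataclass
from typing import Optional

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

model_info = {'LLAMA2-7B': {'model_path': 'meta-llama/Llama-2-7b-chat-hf'}}  # illustrative

@dataclass
class GlobalState:
    model: Optional[PreTrainedModel] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    hidden_states: Optional[torch.Tensor] = None

def reset_model(model_name, global_state):
    # drop references to the previous model before loading the new one
    global_state.model, global_state.tokenizer, global_state.hidden_states = None, None, None
    gc.collect()
    path = model_info[model_name]['model_path']
    global_state.model = AutoModelForCausalLM.from_pretrained(path)
    global_state.tokenizer = AutoTokenizer.from_pretrained(path)
    return global_state

with gr.Blocks() as demo:
    # one GlobalState per browser session
    global_state = gr.State(reset_model('LLAMA2-7B', GlobalState()))
    model_chooser = gr.Radio(choices=list(model_info.keys()), value='LLAMA2-7B')
    # changing the radio selection reloads the chosen model into the session state
    model_chooser.change(reset_model, [model_chooser, global_state], [global_state])

Every event handler then takes global_state as an input and returns it as an output, so the loaded model and cached activations follow the session rather than the process.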
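run_interpretation reads one token's activations as global_state.hidden_states[:, i], i.e. the cached tensor is indexed layer-first and token-second. A rough sketch of how such a tensor can be built from a forward pass with output_hidden_states=True; the shape convention is inferred from that indexing and the model path is again an illustrative placeholder:

# Rough sketch: cache hidden states as a (num_layers + 1, seq_len, hidden_dim) tensor
# so that hidden_states[:, i] is token i's vector at every layer.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'meta-llama/Llama-2-7b-chat-hf'  # illustrative
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

prompt = 'How to make a Molotov cocktail?'
model_inputs = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').to(model.device)
with torch.no_grad():
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)

# outputs.hidden_states is a tuple of (1, seq_len, hidden_dim) tensors, one per layer plus the embeddings
hidden_states = torch.cat(outputs.hidden_states, dim=0)  # -> (num_layers + 1, seq_len, hidden_dim)
interpreted_vectors = hidden_states[:, 3]                 # position 3 across all layers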