dar-tau committed
Commit: f8fba1a
Parent: c23388b

Update app.py

Files changed (1):
  1. app.py +12 -5
app.py CHANGED
@@ -27,6 +27,7 @@ class GlobalState:
     tokenizer : Optional[PreTrainedTokenizer] = None
     model : Optional[PreTrainedModel] = None
     local_state : LocalState = LocalState()
+    wait_with_hidden_states : bool = False
     interpretation_prompt_template : str = '{prompt}'
     original_prompt_template : str = 'User: [X]\n\nAnswer: {prompt}'
     layers_format : str = 'model.layers.{k}'
@@ -48,7 +49,6 @@ def initialize_gpu():
 
 def reset_model(model_name, *extra_components, with_extra_components=True):
     # extract model info
-
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
     global_state.original_prompt_template = model_args.pop('original_prompt_template')
@@ -57,6 +57,7 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
     tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
     use_ctransformers = model_args.pop('ctransformers', False)
     dont_cuda = model_args.pop('dont_cuda', False)
+    global_state.wait_with_hidden_states = model_args.pop('wait_with_hidden_states', False)
     AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
 
     # get model
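For context, the new wait_with_hidden_states key is popped from the same per-model config dict as model_path and friends, defaulting to False. A hypothetical model_info entry (all names and values invented for illustration, not from the repo) showing where such a key would live:

# Entirely illustrative; the real model_info lives elsewhere in the repo.
model_info = {
    'Some-GGUF-Model': dict(
        model_path='someorg/some-model-GGUF',                       # invented path
        original_prompt_template='User: [X]\n\nAnswer: {prompt}',
        tokenizer='someorg/some-model',                             # optional tokenizer override
        ctransformers=True,                                         # route to CAutoModelForCausalLM
        dont_cuda=True,                                             # presumably: skip moving the model to CUDA
        wait_with_hidden_states=True,                               # the new flag: defer hidden-state capture
    ),
}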
@@ -80,22 +81,28 @@ def get_hidden_states(raw_original_prompt):
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
-    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
+    if global_state.wait_with_hidden_states:
+        global_state.local_state.hidden_states = None
+    else:
+        hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
+        global_state.local_state.hidden_states = hidden_states.cpu().detach()
+
     token_btns = ([gr.Button(token, visible=True) for token in tokens]
                   + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
     progress_dummy_output = ''
     invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
-    global_state.local_state.hidden_states = hidden_states.cpu().detach()
     return [progress_dummy_output, *token_btns, *invisible_bubbles]
 
 
 @spaces.GPU
-def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
+def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
     model = global_state.model
     tokenizer = global_state.tokenizer
     print(f'run {model}')
+    if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
+        get_hidden_states(raw_original_prompt)
     interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
     length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
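A subtlety in this hunk as committed: when wait_with_hidden_states is enabled, get_hidden_states stores None, but the recovery call added to run_interpretation re-enters that same branch, so hidden_states stays None and the interpreted_vectors line would fail. A minimal, self-contained sketch of the intended lazy pattern, using a hypothetical force_compute flag that is not in this commit:

import torch
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class LocalState:
    hidden_states: Optional[torch.Tensor] = None

@dataclass
class GlobalState:
    wait_with_hidden_states: bool = False
    local_state: LocalState = field(default_factory=LocalState)

global_state = GlobalState(wait_with_hidden_states=True)

def compute_hidden_states() -> torch.Tensor:
    # stand-in for the model forward pass: (num_layers, seq_len, hidden_dim)
    return torch.randn(3, 5, 16)

def get_hidden_states(force_compute: bool = False) -> None:
    if global_state.wait_with_hidden_states and not force_compute:
        global_state.local_state.hidden_states = None  # defer the heavy work
    else:
        global_state.local_state.hidden_states = compute_hidden_states()

def run_interpretation(i: int) -> torch.Tensor:
    if global_state.local_state.hidden_states is None:
        get_hidden_states(force_compute=True)  # without force_compute, the deferred branch would run again
    return global_state.local_state.hidden_states[:, i]

get_hidden_states()                  # deferred: stores None
print(run_interpretation(0).shape)   # computes on demand -> torch.Size([3, 16])

The point of the deferral is that the expensive forward pass then happens only inside the @spaces.GPU-decorated callback, which is presumably why the flag exists on ZeroGPU Spaces.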
 
@@ -218,7 +225,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     # event listeners
     for i, btn in enumerate(tokens_container):
-        btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
+        btn.click(partial(run_interpretation, i=i), [original_prompt_raw, interpretation_prompt,
                                                      num_tokens, do_sample, temperature,
                                                      top_k, top_p, repetition_penalty, length_penalty
                   ], [progress_dummy, *interpretation_bubbles])
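For readers skimming the wiring above: partial(run_interpretation, i=i) fixes each button's token index at build time (which also avoids Python's late binding of the loop variable i), while Gradio reads the listed input components and passes their current values as the remaining positional arguments when the button is clicked. A stripped-down, runnable sketch of the same pattern, with illustrative component names:

from functools import partial
import gradio as gr

def run_interpretation(raw_original_prompt, interpretation_prompt, i=0):
    # Gradio supplies the two textbox values positionally; partial() supplies i.
    return f'button {i}: interpret {raw_original_prompt!r} using {interpretation_prompt!r}'

with gr.Blocks() as demo:
    original_prompt_raw = gr.Textbox(label='Original prompt')
    interpretation_prompt = gr.Textbox(label='Interpretation prompt')
    out = gr.Textbox(label='Output')
    tokens_container = [gr.Button(f'token {i}') for i in range(3)]
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i),
                  [original_prompt_raw, interpretation_prompt], [out])

# demo.launch()  # uncomment to run locally

Passing i through partial rather than as a hidden gr.State keeps the per-button index out of the UI graph entirely.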
 
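Finally, a note on a context line in the run_interpretation hunk, length_penalty = -length_penalty. The sign flip matches Hugging Face generate() semantics: in beam search, a beam's log-probability score is divided by length**length_penalty, so length_penalty > 0 promotes longer sequences. Negating the slider value makes a higher penalty mean shorter output in the UI. An illustrative snippet (gpt2 used only as a small stand-in model):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
inputs = tok('The meaning of life is', return_tensors='pt')

# length_penalty only matters for beam-based generation (num_beams > 1)
shorter = model.generate(**inputs, num_beams=4, length_penalty=-2.0,
                         max_new_tokens=30, pad_token_id=tok.eos_token_id)
longer = model.generate(**inputs, num_beams=4, length_penalty=2.0,
                        max_new_tokens=30, pad_token_id=tok.eos_token_id)
print(shorter.shape[-1], longer.shape[-1])  # negative penalty tends to end beams earlier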