dar-tau commited on
Commit
ce07d7a
·
verified ·
1 Parent(s): 8326344

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -106,6 +106,8 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
106
  if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
107
  get_hidden_states(raw_original_prompt, force_hidden_states=True)
108
  interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
 
 
109
  length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
110
 
111
  # generation parameters
@@ -131,7 +133,7 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
131
  generation_texts = tokenizer.batch_decode(generated)
132
 
133
  # create GUI output
134
- important_idxs = 1 + interpreted_vectors.diff(dim=0).norm(dim=-1).topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()
135
  print(f'{important_idxs=}')
136
  progress_dummy_output = ''
137
  elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
 
106
  if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
107
  get_hidden_states(raw_original_prompt, force_hidden_states=True)
108
  interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
109
+ hidden_means = torch.tensor(global_state.local_state.hidden_states.mean(dim=1)).to(model.device).to(model.dtype)
110
+ hidden_norms = hidden_means.norm(dim=-1)
111
  length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
112
 
113
  # generation parameters
 
133
  generation_texts = tokenizer.batch_decode(generated)
134
 
135
  # create GUI output
136
+ important_idxs = 1 + ((interpreted_vectors - hidden_means) / hidden_norms).diff(dim=0).norm(dim=-1).topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()
137
  print(f'{important_idxs=}')
138
  progress_dummy_output = ''
139
  elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +