dar-tau commited on
Commit
11b86b4
1 Parent(s): b76e9de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -75,15 +75,15 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
75
  + [*extra_components])
76
 
77
 
78
- def get_hidden_states(raw_original_prompt):
79
  model, tokenizer = global_state.model, global_state.tokenizer
80
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
81
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
82
  tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
83
- outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
84
- if global_state.wait_with_hidden_states:
85
  global_state.local_state.hidden_states = None
86
  else:
 
87
  hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
88
  global_state.local_state.hidden_states = hidden_states.cpu().detach()
89
 
@@ -102,7 +102,7 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
102
  tokenizer = global_state.tokenizer
103
  print(f'run {model}')
104
  if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
105
- get_hidden_states(raw_original_prompt)
106
  interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
107
  length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
108
 
 
75
  + [*extra_components])
76
 
77
 
78
+ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
79
  model, tokenizer = global_state.model, global_state.tokenizer
80
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
81
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
82
  tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
83
+ if global_state.wait_with_hidden_states and not force_hidden_states:
 
84
  global_state.local_state.hidden_states = None
85
  else:
86
+ outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
87
  hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
88
  global_state.local_state.hidden_states = hidden_states.cpu().detach()
89
 
 
102
  tokenizer = global_state.tokenizer
103
  print(f'run {model}')
104
  if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
105
+ get_hidden_states(raw_original_prompt, force_hidden_states=True)
106
  interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
107
  length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
108