Update app.py
app.py CHANGED
@@ -75,15 +75,15 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
             + [*extra_components])
 
 
-def get_hidden_states(raw_original_prompt):
+def get_hidden_states(raw_original_prompt, force_hidden_states=False):
     model, tokenizer = global_state.model, global_state.tokenizer
     original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
-    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
-    if global_state.wait_with_hidden_states:
+    if global_state.wait_with_hidden_states and not force_hidden_states:
         global_state.local_state.hidden_states = None
     else:
+        outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
         hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
         global_state.local_state.hidden_states = hidden_states.cpu().detach()
 
@@ -102,7 +102,7 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
     tokenizer = global_state.tokenizer
     print(f'run {model}')
     if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
-        get_hidden_states(raw_original_prompt)
+        get_hidden_states(raw_original_prompt, force_hidden_states=True)
     interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
     length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
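Taken together, the two hunks defer the forward pass: get_hidden_states now skips the model call when wait_with_hidden_states is set (unless the new force_hidden_states flag is passed), and run_interpretation forces the computation only when the cached hidden states are missing. Below is a minimal sketch of that deferred-computation pattern; GlobalState, LocalState, and run_model are hypothetical stand-ins for the app's real objects, not its actual code.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class LocalState:
    hidden_states: Optional[list] = None

@dataclass
class GlobalState:
    wait_with_hidden_states: bool = True
    local_state: LocalState = field(default_factory=LocalState)

global_state = GlobalState()

def run_model(prompt):
    # Stand-in for the expensive forward pass
    # (model(**model_inputs, output_hidden_states=True, return_dict=True)).
    return [f"hidden_state_{i}({prompt})" for i in range(3)]

def get_hidden_states(raw_original_prompt, force_hidden_states=False):
    if global_state.wait_with_hidden_states and not force_hidden_states:
        # Defer the forward pass; a later caller forces it when needed.
        global_state.local_state.hidden_states = None
    else:
        global_state.local_state.hidden_states = run_model(raw_original_prompt)

def run_interpretation(raw_original_prompt):
    # Pay for the forward pass only once the hidden states are required.
    if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
        get_hidden_states(raw_original_prompt, force_hidden_states=True)
    return global_state.local_state.hidden_states

get_hidden_states("hello")                 # deferred: caches None
assert global_state.local_state.hidden_states is None
print(run_interpretation("hello"))         # forces the computation

The upside is that a prompt edit alone no longer pays for a forward pass; the model only runs once an interpretation actually needs the hidden states.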
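For reference, the torch.stack call in get_hidden_states collapses the tuple returned by output_hidden_states=True into one tensor of shape (num_layers + 1, seq_len, hidden), which run_interpretation later slices as hidden_states[:, i] to get one vector per layer at token position i. A quick sketch of that shape with a small stand-in model; "gpt2" is an arbitrary choice here, not the model this Space loads.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

model_inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)

# outputs.hidden_states is a tuple of (num_layers + 1) tensors (the
# embeddings plus one per layer), each of shape (batch, seq_len, hidden).
hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
print(hidden_states.shape)  # torch.Size([13, 2, 768]) for gpt2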