Update app.py
app.py CHANGED
@@ -27,6 +27,7 @@ class GlobalState:
     tokenizer : Optional[PreTrainedTokenizer] = None
     model : Optional[PreTrainedModel] = None
     local_state : LocalState = LocalState()
+    wait_with_hidden_states : bool = False
     interpretation_prompt_template : str = '{prompt}'
     original_prompt_template : str = 'User: [X]\n\nAnswer: {prompt}'
     layers_format : str = 'model.layers.{k}'
@@ -48,7 +49,6 @@ def initialize_gpu():
 
 
 def reset_model(model_name, *extra_components, with_extra_components=True):
     # extract model info
-
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
     global_state.original_prompt_template = model_args.pop('original_prompt_template')
@@ -57,6 +57,7 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
     tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
     use_ctransformers = model_args.pop('ctransformers', False)
     dont_cuda = model_args.pop('dont_cuda', False)
+    global_state.wait_with_hidden_states = model_args.pop('wait_with_hidden_states', False)
     AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
 
     # get model
@@ -80,22 +81,28 @@ def get_hidden_states(raw_original_prompt):
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
-    hidden_states = torch.stack([h.squeeze(0) for h in outputs.hidden_states], dim=0)
+    if global_state.wait_with_hidden_states:
+        global_state.local_state.hidden_states = None
+    else:
+        hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
+        global_state.local_state.hidden_states = hidden_states.cpu().detach()
+
     token_btns = ([gr.Button(token, visible=True) for token in tokens]
                   + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
     progress_dummy_output = ''
     invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
-    global_state.local_state.hidden_states = hidden_states.cpu().detach()
     return [progress_dummy_output, *token_btns, *invisible_bubbles]
 
 
 @spaces.GPU
-def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
+def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
     model = global_state.model
     tokenizer = global_state.tokenizer
     print(f'run {model}')
+    if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
+        get_hidden_states(raw_original_prompt)
     interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
     length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
@@ -218,7 +225,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     # event listeners
     for i, btn in enumerate(tokens_container):
-        btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
+        btn.click(partial(run_interpretation, i=i), [original_prompt_raw, interpretation_prompt,
                                                      num_tokens, do_sample, temperature,
                                                      top_k, top_p, repetition_penalty, length_penalty
                                                      ], [progress_dummy, *interpretation_bubbles])