Update app.py
app.py CHANGED
@@ -45,25 +45,22 @@ suggested_interpretation_prompts = ["Before responding, let me repeat the messag
 def initialize_gpu():
     pass

-def get_hidden_states(raw_original_prompt
+def get_hidden_states(raw_original_prompt):
     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
-    token_btns = []
-
-
-
-
-
-
-
-
-
-def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
-                       temperature, top_k, top_p, repetition_penalty, length_penalty, interpreted_vectors, num_beams=1):
-
+    token_btns = ([gr.Button(token, visible=True) for token in tokens]
+                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
+    return [hidden_states, *token_btns]
+
+
+def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
+                       num_beams=1):
+
+    interpreted_vectors = global_state[:, i]
     length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it

     # generation parameters
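
For orientation, a minimal sketch (plain torch, toy sizes) of the tensor layout `get_hidden_states` builds: stacking `outputs.hidden_states` yields shape `(num_layers + 1, seq_len, hidden_dim)`, which is why `run_interpretation` can select token `i`'s residual stream across every layer with `global_state[:, i]`.

import torch

# Toy sizes; the real model's layer count and hidden size differ.
num_layers, seq_len, hidden_dim = 4, 7, 16
hidden_states = torch.randn(num_layers + 1, seq_len, hidden_dim)  # +1 for the embedding layer

i = 3                                      # token position, as bound via partial(...) below
interpreted_vectors = hidden_states[:, i]  # token i across every layer
assert interpreted_vectors.shape == (num_layers + 1, hidden_dim)
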
@@ -83,7 +80,7 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)

     # generate the interpretations
-    generated = interpretation_prompt.generate(model, {0:
+    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     return generation_texts

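
The splatted `generation_kwargs` dict is built outside this hunk. As a hedged sketch only, it presumably collects the UI parameters into standard transformers `generate()` arguments, along these lines (the helper name and exact keys are assumptions, not app.py's actual code):

def build_generation_kwargs(max_new_tokens, do_sample, temperature, top_k, top_p,
                            repetition_penalty, length_penalty, num_beams=1):
    # Assumed shape of the dict; app.py's real construction is not shown in this diff.
    return dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,  # already negated by run_interpretation
        num_beams=num_beams,
    )
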
@@ -105,6 +102,8 @@ model = AutoModelClass.from_pretrained(model_name, **model_args)
 tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])

 # demo
+global_state = gr.State([])
+json_output = gr.JSON()
 with gr.Blocks(theme=gr.themes.Default()) as demo:
     with gr.Row():
         with gr.Column(scale=5):
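
Note that `global_state` and `json_output` are instantiated before the `gr.Blocks` context and attached later (`json_output.render()` inside the layout). A minimal sketch of the `gr.State` round-trip this enables, assuming gradio 4.x semantics where a State's value is passed to callbacks as an input and overwritten via outputs:

import gradio as gr

with gr.Blocks() as sketch:
    state = gr.State([])                   # per-session Python value, never rendered
    log = gr.Textbox(label='clicks')
    btn = gr.Button('add')

    def add(history):
        history = history + ['click']      # copy, then update the session value
        return history, ', '.join(history)

    btn.click(add, [state], [state, log])  # State used as both input and output

sketch.launch()
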
@@ -144,15 +143,15 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')

     with gr.Group('Output'):
+        tokens_container = []
         with gr.Row():
-
-
-
-
-
-
-
-
-
-), [original_prompt_raw], [*tokens_container])
+            for i in range(MAX_PROMPT_TOKENS):
+                btn = gr.Button('', visible=False)
+                btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
+                                                             top_k, top_p, repetition_penalty, length_penalty
+                                                             ], [json_output])
+                tokens_container.append(btn)
+        json_output.render()
+
+    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [global_state, *tokens_container])
     demo.launch()
|