Update app.py
Browse files
app.py
CHANGED
@@ -84,13 +84,15 @@ def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_trans
|
|
84 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
85 |
gc.collect()
|
86 |
if with_extra_components:
|
87 |
-
return ([welcome_message.format(model_name=model_name)]
|
88 |
+ [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
|
89 |
+ [gr.Button('', visible=False) for _ in range(len(tokens_container))]
|
90 |
+ [*extra_components])
|
|
|
|
|
91 |
|
92 |
|
93 |
-
def get_hidden_states(raw_original_prompt, force_hidden_states=False):
|
94 |
model, tokenizer = global_state.model, global_state.tokenizer
|
95 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
96 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
@@ -118,7 +120,7 @@ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
|
|
118 |
|
119 |
|
120 |
@spaces.GPU
|
121 |
-
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
122 |
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
123 |
num_beams=1):
|
124 |
model = global_state.model
|
@@ -186,9 +188,9 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
|
|
186 |
|
187 |
## main
|
188 |
torch.set_grad_enabled(False)
|
189 |
-
global_state = gr.State(GlobalState)
|
|
|
190 |
model_name = 'LLAMA2-7B'
|
191 |
-
reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
|
192 |
raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
193 |
tokens_container = []
|
194 |
|
@@ -288,7 +290,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
288 |
raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
289 |
|
290 |
extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
|
291 |
-
model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components],
|
292 |
-
[welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
|
293 |
|
294 |
demo.launch()
|
|
|
84 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
85 |
gc.collect()
|
86 |
if with_extra_components:
|
87 |
+
return ([global_state, welcome_message.format(model_name=model_name)]
|
88 |
+ [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
|
89 |
+ [gr.Button('', visible=False) for _ in range(len(tokens_container))]
|
90 |
+ [*extra_components])
|
91 |
+
else:
|
92 |
+
return global_state
|
93 |
|
94 |
|
95 |
+
def get_hidden_states(global_state, raw_original_prompt, force_hidden_states=False):
|
96 |
model, tokenizer = global_state.model, global_state.tokenizer
|
97 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
98 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
|
120 |
|
121 |
|
122 |
@spaces.GPU
|
123 |
+
def run_interpretation(global_state, raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
124 |
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
125 |
num_beams=1):
|
126 |
model = global_state.model
|
|
|
188 |
|
189 |
## main
|
190 |
torch.set_grad_enabled(False)
|
191 |
+
global_state = gr.State(partial(reset_model, GlobalState(),
|
192 |
+
model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True))
|
193 |
model_name = 'LLAMA2-7B'
|
|
|
194 |
raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
195 |
tokens_container = []
|
196 |
|
|
|
290 |
raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
291 |
|
292 |
extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
|
293 |
+
model_chooser.change(reset_model, [global_state, model_chooser, load_on_gpu, *extra_components],
|
294 |
+
[global_state, welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
|
295 |
|
296 |
demo.launch()
|