dar-tau committed on
Commit
f269195
·
verified ·
1 Parent(s): bf5b0c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -84,13 +84,15 @@ def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_trans
84
  global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
85
  gc.collect()
86
  if with_extra_components:
87
- return ([welcome_message.format(model_name=model_name)]
88
  + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
89
  + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
90
  + [*extra_components])
 
 
91
 
92
 
93
- def get_hidden_states(raw_original_prompt, force_hidden_states=False):
94
  model, tokenizer = global_state.model, global_state.tokenizer
95
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
96
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
@@ -118,7 +120,7 @@ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
118
 
119
 
120
  @spaces.GPU
121
- def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
122
  temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
123
  num_beams=1):
124
  model = global_state.model
@@ -186,9 +188,9 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
186
 
187
  ## main
188
  torch.set_grad_enabled(False)
189
- global_state = gr.State(GlobalState)
 
190
  model_name = 'LLAMA2-7B'
191
- reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
192
  raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
193
  tokens_container = []
194
 
@@ -288,7 +290,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
288
  raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
289
 
290
  extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
291
- model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components],
292
- [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
293
 
294
  demo.launch()
 
84
  global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
85
  gc.collect()
86
  if with_extra_components:
87
+ return ([global_state, welcome_message.format(model_name=model_name)]
88
  + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
89
  + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
90
  + [*extra_components])
91
+ else:
92
+ return global_state
93
 
94
 
95
+ def get_hidden_states(global_state, raw_original_prompt, force_hidden_states=False):
96
  model, tokenizer = global_state.model, global_state.tokenizer
97
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
98
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
 
120
 
121
 
122
  @spaces.GPU
123
+ def run_interpretation(global_state, raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
124
  temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
125
  num_beams=1):
126
  model = global_state.model
 
188
 
189
  ## main
190
  torch.set_grad_enabled(False)
191
+ global_state = gr.State(partial(reset_model, GlobalState(),
192
+ model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True))
193
  model_name = 'LLAMA2-7B'
 
194
  raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
195
  tokens_container = []
196
 
 
290
  raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
291
 
292
  extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
293
+ model_chooser.change(reset_model, [global_state, model_chooser, load_on_gpu, *extra_components],
294
+ [global_state, welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
295
 
296
  demo.launch()