dar-tau committed · verified · Commit a34def0 · Parent(s): c46b218

Update app.py

Files changed (1): app.py (+5 -4)
app.py CHANGED
@@ -52,7 +52,7 @@ suggested_interpretation_prompts = [
 def initialize_gpu():
     pass
 
-def reset_model(model_name, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
+def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -72,7 +72,7 @@ def reset_model(model_name, *extra_components, reset_sentence_transformer=False,
     if reset_sentence_transformer:
         global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
     gc.collect()
-    if not dont_cuda:
+    if load_on_gpu and not dont_cuda:
         global_state.model.to('cuda')
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
@@ -175,7 +175,7 @@ torch.set_grad_enabled(False)
 global_state = GlobalState()
 
 model_name = 'LLAMA2-7B'
-reset_model(model_name, with_extra_components=False, reset_sentence_transformer=True)
+reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
 raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
 tokens_container = []
 
@@ -211,6 +211,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     with gr.Group():
         model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
+        load_on_gpu = gr.Checkbox(label='Load on GPU', value=True)
         welcome_model = gr.Markdown(welcome_message.format(model_name=model_name))
     with gr.Blocks() as demo_main:
         gr.Markdown('## The Prompt to Analyze')
@@ -274,7 +275,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
     raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
 
     extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
-    model_chooser.change(reset_model, [model_chooser, *extra_components],
+    model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components],
                          [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
 
 demo.launch()
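
Taken together, the commit threads a new load_on_gpu flag from the UI into reset_model: a checkbox is added next to the model radio, its value is passed positionally through model_chooser.change, and the .to('cuda') call is gated on it. Since Gradio binds event inputs by position rather than by keyword, the flag has to come before *extra_components in the signature, mirroring its place in the inputs list. Below is a minimal, self-contained sketch of that wiring pattern; the stub handler and the single model choice are hypothetical stand-ins, not the Space's actual code:

import gradio as gr

# Hypothetical stand-in for app.py's reset_model: Gradio passes each input
# component's current value positionally, so the checkbox state arrives as
# the second argument, right after the radio's selected model name.
def reset_model(model_name, load_on_gpu):
    device = 'cuda' if load_on_gpu else 'cpu'
    return f'{model_name} would be loaded on {device}'

with gr.Blocks() as demo:
    model_chooser = gr.Radio(label='Choose Your Model',
                             choices=['LLAMA2-7B'], value='LLAMA2-7B')
    load_on_gpu = gr.Checkbox(label='Load on GPU', value=True)
    status = gr.Markdown()
    # Same input ordering as the commit: the checkbox follows the radio.
    model_chooser.change(reset_model, [model_chooser, load_on_gpu], [status])

demo.launch()

Toggling the checkbox only takes effect on the next model switch, since it is read inside the change handler rather than wired to its own event.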