Update app.py
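Adds a 'Load on GPU' checkbox next to the model chooser and threads its value through reset_model as a new positional parameter, so the model is moved to CUDA only when the box is checked (and CUDA isn't otherwise disabled via dont_cuda).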
app.py CHANGED
@@ -52,7 +52,7 @@ suggested_interpretation_prompts = [
 def initialize_gpu():
     pass
 
-def reset_model(model_name, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
+def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -72,7 +72,7 @@ def reset_model(model_name, *extra_components, reset_sentence_transformer=False,
     if reset_sentence_transformer:
         global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
     gc.collect()
-    if not dont_cuda:
+    if load_on_gpu and not dont_cuda:
         global_state.model.to('cuda')
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
@@ -175,7 +175,7 @@ torch.set_grad_enabled(False)
 global_state = GlobalState()
 
 model_name = 'LLAMA2-7B'
-reset_model(model_name, with_extra_components=False, reset_sentence_transformer=True)
+reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
 raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
 tokens_container = []
 
@@ -211,6 +211,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     with gr.Group():
         model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
+        load_on_gpu = gr.Checkbox(label='Load on GPU', value=True)
         welcome_model = gr.Markdown(welcome_message.format(model_name=model_name))
     with gr.Blocks() as demo_main:
         gr.Markdown('## The Prompt to Analyze')
@@ -274,7 +275,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
     raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
 
     extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
-    model_chooser.change(reset_model, [model_chooser, *extra_components],
+    model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components],
                          [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
 
 demo.launch()
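Note on the wiring: Gradio passes the current value of each input component to an event callback as a positional argument, in the order the components are listed. Keyword-only parameters after *extra_components (reset_sentence_transformer, with_extra_components) cannot be supplied this way, which is presumably why load_on_gpu was added as a positional parameter ahead of *extra_components. A minimal, self-contained sketch of the pattern (not the app's actual code; dont_cuda and the model names below are stand-ins):

# Sketch of this commit's pattern: a gr.Checkbox whose value is forwarded
# positionally to the change-callback, which then gates the .to('cuda') move.
import gc
import gradio as gr
import torch

dont_cuda = not torch.cuda.is_available()  # assumed stand-in for the app's flag
MODEL_CHOICES = ['LLAMA2-7B', 'MISTRAL-7B']  # placeholder model names

def reset_model_stub(model_name, load_on_gpu):
    model = torch.nn.Linear(4, 4)  # tiny stand-in for the real checkpoint
    if load_on_gpu and not dont_cuda:
        model.to('cuda')  # move weights to GPU only when requested and possible
    gc.collect()
    device = next(model.parameters()).device
    return f'Loaded {model_name} on {device}'

with gr.Blocks() as demo:
    model_chooser = gr.Radio(label='Choose Your Model', choices=MODEL_CHOICES,
                             value=MODEL_CHOICES[0])
    load_on_gpu = gr.Checkbox(label='Load on GPU', value=True)
    status = gr.Markdown()
    # Gradio supplies each input's current value positionally, in input order,
    # so load_on_gpu arrives as the callback's second parameter.
    model_chooser.change(reset_model_stub, [model_chooser, load_on_gpu], [status])

demo.launch()

As in the commit, the callback is bound only to the radio's change event, so toggling the checkbox alone does not reload the model; its value is read the next time the model selection changes.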