Update app.py
Browse files
app.py
CHANGED
@@ -32,10 +32,10 @@ model_info = {
|
|
32 |
interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
|
33 |
), # , load_in_8bit=True
|
34 |
|
35 |
-
'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
|
40 |
'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
|
41 |
original_prompt_template='<s>{prompt}',
|
@@ -75,7 +75,7 @@ def initialize_gpu():
|
|
75 |
pass
|
76 |
|
77 |
|
78 |
-
def reset_model(model_name
|
79 |
# extract model info
|
80 |
model_args = deepcopy(model_info[model_name])
|
81 |
model_path = model_args.pop('model_path')
|
@@ -91,10 +91,7 @@ def reset_model(model_name, return_extra_components=True):
|
|
91 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
92 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
93 |
gc.collect()
|
94 |
-
|
95 |
-
extra_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
|
96 |
-
original_prompt_raw]
|
97 |
-
return extra_components
|
98 |
|
99 |
|
100 |
def get_hidden_states(raw_original_prompt):
|
@@ -151,7 +148,7 @@ torch.set_grad_enabled(False)
|
|
151 |
global_state = GlobalState()
|
152 |
|
153 |
model_name = 'LLAMA2-7B'
|
154 |
-
reset_model(model_name
|
155 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
156 |
tokens_container = []
|
157 |
|
@@ -238,7 +235,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
238 |
# event listeners
|
239 |
extra_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
|
240 |
original_prompt_raw]
|
241 |
-
model_chooser.change(reset_model, [model_chooser], extra_components)
|
242 |
|
243 |
for i, btn in enumerate(tokens_container):
|
244 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|
|
|
32 |
interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
|
33 |
), # , load_in_8bit=True
|
34 |
|
35 |
+
# 'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
|
36 |
+
# original_prompt_template='<bos>{prompt}',
|
37 |
+
# interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
|
38 |
+
# ),
|
39 |
|
40 |
'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
|
41 |
original_prompt_template='<s>{prompt}',
|
|
|
75 |
pass
|
76 |
|
77 |
|
78 |
+
def reset_model(model_name):
|
79 |
# extract model info
|
80 |
model_args = deepcopy(model_info[model_name])
|
81 |
model_path = model_args.pop('model_path')
|
|
|
91 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
92 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
93 |
gc.collect()
|
94 |
+
return extra_components
|
|
|
|
|
|
|
95 |
|
96 |
|
97 |
def get_hidden_states(raw_original_prompt):
|
|
|
148 |
global_state = GlobalState()
|
149 |
|
150 |
model_name = 'LLAMA2-7B'
|
151 |
+
reset_model(model_name)
|
152 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
153 |
tokens_container = []
|
154 |
|
|
|
235 |
# event listeners
|
236 |
extra_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
|
237 |
original_prompt_raw]
|
238 |
+
model_chooser.change(reset_model, [model_chooser, extra_components], extra_components)
|
239 |
|
240 |
for i, btn in enumerate(tokens_container):
|
241 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|