Update app.py
Browse files
app.py
CHANGED
@@ -74,7 +74,7 @@ def initialize_gpu():
|
|
74 |
pass
|
75 |
|
76 |
|
77 |
-
def reset_model(model_name
|
78 |
# extract model info
|
79 |
model_args = deepcopy(model_info[model_name])
|
80 |
model_path = model_args.pop('model_path')
|
@@ -90,8 +90,6 @@ def reset_model(model_name, return_state=False):
|
|
90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
92 |
gc.collect()
|
93 |
-
if return_state:
|
94 |
-
return global_state
|
95 |
|
96 |
|
97 |
def get_hidden_states(raw_original_prompt):
|
@@ -145,11 +143,13 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
|
|
145 |
|
146 |
## main
|
147 |
torch.set_grad_enabled(False)
|
|
|
148 |
|
149 |
model_name = 'LLAMA2-7B'
|
150 |
-
|
151 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
152 |
tokens_container = []
|
|
|
153 |
for i in range(MAX_PROMPT_TOKENS):
|
154 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
155 |
tokens_container.append(btn)
|
|
|
74 |
pass
|
75 |
|
76 |
|
77 |
+
def reset_model(model_name):
|
78 |
# extract model info
|
79 |
model_args = deepcopy(model_info[model_name])
|
80 |
model_path = model_args.pop('model_path')
|
|
|
90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
92 |
gc.collect()
|
|
|
|
|
93 |
|
94 |
|
95 |
def get_hidden_states(raw_original_prompt):
|
|
|
143 |
|
144 |
## main
|
145 |
torch.set_grad_enabled(False)
|
146 |
+
global_state = GlobalState()
|
147 |
|
148 |
model_name = 'LLAMA2-7B'
|
149 |
+
reset_model(model_name)
|
150 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
151 |
tokens_container = []
|
152 |
+
|
153 |
for i in range(MAX_PROMPT_TOKENS):
|
154 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
155 |
tokens_container.append(btn)
|