Update app.py
--- a/app.py
+++ b/app.py
@@ -41,7 +41,7 @@ suggested_interpretation_prompts = [
 def initialize_gpu():
     pass
 
-def reset_model(model_name, *extra_components):
+def reset_model(model_name, *extra_components, with_extra_components=True):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -58,7 +58,10 @@ def reset_model(model_name, *extra_components):
     global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
-
+    if with_extra_components:
+        for x in interpretation_bubbles:
+            x.visible = False
+        return extra_components
 
 
 def get_hidden_states(raw_original_prompt):
@@ -101,8 +104,8 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
     interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)
 
     # generate the interpretations
-
-
+    generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors},
+                                               layers_format=global_state.layers_format, k=3,
                                                **generation_kwargs)
     generation_texts = global_state.tokenizer.batch_decode(generated)
     progress_dummy_output = ''
@@ -116,7 +119,7 @@ torch.set_grad_enabled(False)
 global_state = GlobalState()
 
 model_name = 'LLAMA2-7B'
-reset_model(model_name)
+reset_model(model_name, with_extra_components=False)
 original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
 tokens_container = []
 
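For context, a minimal sketch of how the two call modes of reset_model could be wired, assuming a hypothetical gr.Dropdown named model_dropdown plus the model_info dict and interpretation_bubbles list already defined in app.py; this wiring is not part of the commit itself:

    import gradio as gr

    # At import time no UI components exist yet, so the module-level call
    # opts out of component handling (matches this commit):
    reset_model(model_name, with_extra_components=False)

    with gr.Blocks() as demo:
        # Hypothetical model selector; not in this diff.
        model_dropdown = gr.Dropdown(choices=list(model_info.keys()),
                                     value=model_name, label='Model')
        # On a model switch, reset_model uses its default
        # with_extra_components=True: it hides the stale interpretation
        # bubbles and returns the components for Gradio to re-render.
        model_dropdown.change(reset_model,
                              [model_dropdown] + interpretation_bubbles,
                              interpretation_bubbles)

The keyword-only flag lets one function serve both cases: the start-up path, where returning Gradio components would be meaningless, and the event-handler path, where the returned components update the UI.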