Update app.py
Browse files
app.py
CHANGED
@@ -91,19 +91,20 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
91 |
|
92 |
## main
|
93 |
torch.set_grad_enabled(False)
|
94 |
-
model_name = '
|
95 |
|
96 |
# extract model info
|
97 |
model_args = deepcopy(model_info[model_name])
|
|
|
98 |
original_prompt_template = model_args.pop('original_prompt_template')
|
99 |
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
|
100 |
-
|
101 |
use_ctransformers = model_args.pop('ctransformers', False)
|
102 |
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
|
103 |
|
104 |
# get model
|
105 |
-
model = AutoModelClass.from_pretrained(
|
106 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
107 |
|
108 |
# demo
|
109 |
json_output = gr.JSON()
|
|
|
91 |
|
92 |
## main
|
93 |
torch.set_grad_enabled(False)
|
94 |
+
model_name = 'LLAMA2-7B'
|
95 |
|
96 |
# extract model info
|
97 |
model_args = deepcopy(model_info[model_name])
|
98 |
+
model_path = model_args.pop('model_path')
|
99 |
original_prompt_template = model_args.pop('original_prompt_template')
|
100 |
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
|
101 |
+
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
|
102 |
use_ctransformers = model_args.pop('ctransformers', False)
|
103 |
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
|
104 |
|
105 |
# get model
|
106 |
+
model = AutoModelClass.from_pretrained(model_path, **model_args)
|
107 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
108 |
|
109 |
# demo
|
110 |
json_output = gr.JSON()
|