dar-tau committed on
Commit
3e36699
1 Parent(s): f2d60cb

Update app.py

Files changed (1)
  1. app.py +5 -4
app.py CHANGED
@@ -91,19 +91,20 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
 
 ## main
 torch.set_grad_enabled(False)
-model_name = 'meta-llama/Llama-2-7b-chat-hf' # 'mistralai/Mistral-7B-Instruct-v0.2' #
+model_name = 'LLAMA2-7B'
 
 # extract model info
 model_args = deepcopy(model_info[model_name])
+model_path = model_args.pop('model_path')
 original_prompt_template = model_args.pop('original_prompt_template')
 interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
-tokenizer_name = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_name
+tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
 use_ctransformers = model_args.pop('ctransformers', False)
 AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
 
 # get model
-model = AutoModelClass.from_pretrained(model_name, **model_args)
-tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])
+model = AutoModelClass.from_pretrained(model_path, **model_args)
+tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
 
 # demo
 json_output = gr.JSON()
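
With this commit, the Hugging Face repo id no longer doubles as the lookup key: model_name is now a short display key into model_info, and the actual repo path comes from a model_path field in each entry. A minimal sketch of what the 'LLAMA2-7B' entry might look like under these assumptions (the field values below are illustrative, not taken from the repository):

# Hypothetical shape of a model_info entry consistent with the updated app.py.
# Only the key names popped in the diff (model_path, original_prompt_template,
# interpretation_prompt_template, tokenizer, ctransformers) are grounded in the code;
# the values here are assumptions.
model_info = {
    'LLAMA2-7B': {
        'model_path': 'meta-llama/Llama-2-7b-chat-hf',   # assumed repo id, previously hard-coded as model_name
        'original_prompt_template': '[INST] {prompt} [/INST]',        # assumed template string
        'interpretation_prompt_template': '[INST] {prompt} [/INST]',  # assumed template string
        # optional keys: 'tokenizer': '<separate tokenizer repo>', 'ctransformers': True
        'device_map': 'auto',   # assumption: any remaining keys are forwarded to from_pretrained
    },
}

Whatever is left in model_args after the pops is forwarded as keyword arguments to AutoModelClass.from_pretrained(model_path, **model_args), so per-model loading options can live in the same entry.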