kcarnold commited on
Commit
6f71907
·
1 Parent(s): fa263f0

fix a missing import

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -20,6 +20,7 @@ def get_tokenizer(model_name):
20
 
21
  @st.cache_resource
22
  def get_model(model_name):
 
23
  from transformers import AutoModelForCausalLM
24
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
25
  print(f"Loaded model, {model.num_parameters():,d} parameters.")
 
20
 
21
  @st.cache_resource
22
  def get_model(model_name):
23
+ import torch
24
  from transformers import AutoModelForCausalLM
25
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
26
  print(f"Loaded model, {model.num_parameters():,d} parameters.")