NiCEtmtm commited on
Commit
88e7970
1 Parent(s): bfeced2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -25,11 +25,11 @@ class ModelHandler:
25
 
26
  def load_model(self):
27
  # Load token as env var
28
- model_id = "NiCETmtm/Llama3_kw_gen_new"
29
  token = os.getenv("HF_API_TOKEN")
30
  # Load model & tokenizer
31
- self.model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, from_tf=True)
32
- self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
33
 
34
  def predict(self, inputs):
35
  tokens = self.tokenizer(inputs, return_tensors="pt")
 
25
 
26
  def load_model(self):
27
  # Load token as env var
28
+ model_id = "NiCETmtm/llama3_torch"
29
  token = os.getenv("HF_API_TOKEN")
30
  # Load model & tokenizer
31
+ self.model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, trust_remote_code=True, from_tf=True)
32
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token, trust_remote_code=True)
33
 
34
  def predict(self, inputs):
35
  tokens = self.tokenizer(inputs, return_tensors="pt")