Update handler.py
Browse files- handler.py +3 -3
handler.py
CHANGED
@@ -25,11 +25,11 @@ class ModelHandler:
|
|
25 |
|
26 |
def load_model(self):
|
27 |
# Load token as env var
|
28 |
-
model_id = "NiCETmtm/
|
29 |
token = os.getenv("HF_API_TOKEN")
|
30 |
# Load model & tokenizer
|
31 |
-
self.model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, from_tf=True)
|
32 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
|
33 |
|
34 |
def predict(self, inputs):
|
35 |
tokens = self.tokenizer(inputs, return_tensors="pt")
|
|
|
25 |
|
26 |
def load_model(self):
|
27 |
# Load token as env var
|
28 |
+
model_id = "NiCETmtm/llama3_torch"
|
29 |
token = os.getenv("HF_API_TOKEN")
|
30 |
# Load model & tokenizer
|
31 |
+
self.model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, trust_remote_code=True, from_tf=True)
|
32 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token, trust_remote_code=True)
|
33 |
|
34 |
def predict(self, inputs):
|
35 |
tokens = self.tokenizer(inputs, return_tensors="pt")
|