# llama3_torch / handler.py
# (Header text copied from the Hugging Face Hub file page, kept as comments:
#  uploaded by NiCEtmtm; last commit 88e7970 "Update handler.py" (verified); 1.37 kB.)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import subprocess
def install(package):
    """Install ``package`` (a pip requirement spec such as ``"pkg==1.2.3"``).

    Runs ``<current interpreter> -m pip install <package>`` so the package lands in
    the same environment that is executing this file.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    # `sys` is not imported at the top of this file; without this local import the
    # call below fails with NameError before pip is ever run.
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
def _ensure_package(module_name, pip_spec):
    """Import ``module_name``, pip-installing ``pip_spec`` first if the import fails.

    Returns:
        The imported module.

    Raises:
        ImportError: if the module still cannot be imported after installation.
        subprocess.CalledProcessError: if the pip install itself fails.
    """
    import importlib

    try:
        return importlib.import_module(module_name)
    except ImportError:
        install(pip_spec)
        # A package installed while the interpreter is running may be invisible to
        # the import system until its finder caches are cleared.
        importlib.invalidate_caches()
        # Import right away so a broken install fails here, loudly, instead of
        # leaving the module name silently unbound.
        return importlib.import_module(module_name)


# Runtime dependencies that may not be preinstalled; versions are pinned.
# NOTE(review): transformers is imported before these run, and it may cache which
# optional packages are available at its own import time. Confirm that freshly
# installed packages are actually picked up, or install them before importing
# transformers.
bitsandbytes = _ensure_package("bitsandbytes", "bitsandbytes==0.39.1")
accelerate = _ensure_package("accelerate", "accelerate==0.20.0")
class ModelHandler:
    """Load a causal-LM checkpoint from the Hugging Face Hub and generate text with it.

    Args:
        model_id: Hub repository to load. Defaults to the repository this handler
            was written for, so ``ModelHandler()`` behaves as before.
    """

    def __init__(self, model_id="NiCETmtm/llama3_torch"):
        self.model_id = model_id
        self.model = None  # set by load_model()
        self.tokenizer = None  # set by load_model()

    def load_model(self):
        """Load the model and tokenizer for ``self.model_id`` via ``from_pretrained``.

        The Hub token is read from the ``HF_API_TOKEN`` environment variable. When the
        variable is unset, ``None`` is passed through to transformers.
        """
        token = os.getenv("HF_API_TOKEN")
        # `from_tf=True` was removed. It tells transformers to load and convert a
        # TensorFlow checkpoint. This repo is a PyTorch checkpoint ("llama3_torch"),
        # and transformers ships no TensorFlow Llama class to convert from, so that
        # flag makes loading fail.
        # NOTE(review): trust_remote_code=True executes Python code stored in the
        # model repository. Keep it only if the repo needs custom modeling code and
        # its contents are trusted.
        # NOTE(review): newer transformers releases deprecate `use_auth_token` in
        # favor of `token=`. It is kept here for compatibility with older releases;
        # confirm the installed version.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id, use_auth_token=token, trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id, use_auth_token=token, trust_remote_code=True
        )

    def predict(self, inputs, **generate_kwargs):
        """Generate a continuation for ``inputs`` and return it as decoded text.

        If ``load_model()`` has not been called yet, it is called first instead of
        failing on the ``None`` placeholders.

        Args:
            inputs: Prompt passed to the tokenizer. Presumably a single string; only
                the first generated sequence is decoded.
            **generate_kwargs: Forwarded to ``model.generate`` (e.g.
                ``max_new_tokens=256``). With none given, the model's default
                generation settings apply, exactly as before.

        Returns:
            The decoded text of the first output sequence, with special tokens
            stripped.
        """
        if self.model is None or self.tokenizer is None:
            self.load_model()
        tokens = self.tokenizer(inputs, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Module-level shared handler: constructed and loaded once when this file is imported.
# The model and tokenizer are therefore loaded (via from_pretrained) before the first
# inference() call, not during it.
model_handler = ModelHandler()
model_handler.load_model()
def inference(event, context):
    """Handle one request: generate text for ``event["data"]``.

    Uses the module-level ``model_handler``. ``context`` is accepted but unused;
    presumably it exists to satisfy a serverless-style ``(event, context)``
    signature.

    Returns:
        ``{"predictions": <decoded text>}``. A missing ``"data"`` key raises KeyError.
    """
    return {"predictions": model_handler.predict(event["data"])}