"""Custom Hugging Face Inference Endpoints handler for Llama-2-7b-Chat-GPTQ (handler.py)."""
from typing import Any, Dict, List

from transformers import AutoTokenizer, pipeline
from auto_gptq import AutoGPTQForCausalLM


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and the GPTQ-quantized model from the
        # repository path supplied by the endpoint runtime.
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        model = AutoGPTQForCausalLM.from_quantized(
            path,
            use_safetensors=True,
            trust_remote_code=False,
            use_triton=False,
            quantize_config=None,
        )
        # Create the text-generation inference pipeline.
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # Forward any generation parameters supplied in the request.
        if parameters is not None:
            prediction = self.pipeline(inputs, **parameters)
        else:
            prediction = self.pipeline(inputs)
        # Return the raw pipeline output: a list of dicts, each with a
        # "generated_text" field.
        return prediction
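

# A minimal local smoke test, illustrative only and not part of the deployed
# handler contract. It assumes the quantized model files have already been
# downloaded to the current directory (path="." is a placeholder assumption);
# the payload shape follows the Inference Endpoints request convention used
# by __call__ above.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumes model files live here
    payload = {
        "inputs": "What is the capital of France?",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    }
    print(handler(payload))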