# llama-test / handler.py
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline

# Prefer bfloat16 on GPUs with compute capability 8.x, otherwise fall back to float16
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and the 4-bit quantized model (requires bitsandbytes)
        self.tokenizer = LlamaTokenizer.from_pretrained(path)
        model = LlamaForCausalLM.from_pretrained(
            path, load_in_4bit=True, device_map=0, torch_dtype=dtype
        )
        # Create the text-generation inference pipeline
        self.pipeline = pipeline("text-generation", model=model, tokenizer=self.tokenizer)

    # Alternative generic handler kept for reference:
    # def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
    #     inputs = data.pop("inputs", data)
    #     parameters = data.pop("parameters", None)
    #     # pass inputs with all kwargs in data
    #     if parameters is not None:
    #         prediction = self.pipeline(inputs, **parameters)
    #     else:
    #         prediction = self.pipeline(inputs)
    #     # postprocess the prediction
    #     return prediction

    def __call__(self, message: str) -> str:
        # Generate a completion for the prompt with the original sampling settings
        sequences = self.pipeline(
            message,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_length=2048,
        )
        generated_text = sequences[0]["generated_text"]
        response = generated_text[len(message):]  # Remove the prompt from the output
        print("Chatbot:", response.strip())
        return response.strip()
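

# --- Local usage sketch (an assumption, not part of the deployed handler) ---
# A minimal way to exercise the handler outside of an Inference Endpoint,
# assuming the Llama weights live in a local directory; the path below is hypothetical.
if __name__ == "__main__":
    handler = EndpointHandler(path="./llama-model")  # hypothetical local model path
    reply = handler("Hello, how are you?")
    print(reply)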