import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline

# Use bfloat16 on Ampere or newer GPUs (compute capability >= 8), float16 otherwise.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and a 4-bit quantized model from the checkpoint at `path`.
        self.tokenizer = LlamaTokenizer.from_pretrained(path)
        model = LlamaForCausalLM.from_pretrained(
            path,
            load_in_4bit=True,
            device_map=0,
            torch_dtype=dtype,  # use the capability-dependent dtype computed above
        )
        self.pipeline = pipeline("text-generation", model=model, tokenizer=self.tokenizer)

    def __call__(self, message: str):
        # Sample a single completion for the incoming message; max_length
        # counts prompt tokens plus generated tokens.
        sequences = self.pipeline(
            message,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_length=2048,
        )

        # Slice off the echoed prompt so only the newly generated reply remains.
        generated_text = sequences[0]['generated_text']
        response = generated_text[len(message):]

        print("Chatbot:", response.strip())

        return response.strip()
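

# Usage sketch: a minimal local smoke test of the handler. The checkpoint path
# below is a placeholder assumption, not something defined by this handler; any
# local or Hub checkpoint compatible with LlamaForCausalLM would do.
if __name__ == "__main__":
    handler = EndpointHandler(path="meta-llama/Llama-2-7b-chat-hf")  # hypothetical path
    print(handler("Hello! Tell me a fun fact about llamas."))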