import os
from typing import Any, Dict, List

import torch
from huggingface_hub import login
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template


class EndpointHandler:
    def __init__(self, path=""):
        # Optionally authenticate with the Hugging Face Hub if the model repo is gated.
        # access_token = os.environ["HUGGINGFACE_TOKEN"]
        # login(token=access_token)

        # Load the model and tokenizer from the deployed repository path.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,      # path of the deployed repository
            max_seq_length=2048,
            dtype=None,           # let Unsloth pick a suitable dtype for the GPU
            load_in_4bit=True,    # 4-bit quantization to reduce memory usage
        )
        FastLanguageModel.for_inference(self.model)

        # Set up the llama-3 chat template, mapping ShareGPT-style keys
        # ("from"/"value", "human"/"gpt") onto it.
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template="llama-3",
            mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"},
        )

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Extract generation parameters or fall back to defaults.
        max_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.2)
        top_p = parameters.get("top_p", 0.5)
        system_message = parameters.get("system_message", "")

        # Prepare messages in ShareGPT format; prepend the system message only if one was given.
        messages = []
        if system_message:
            messages.append({"from": "human", "value": system_message})
        if isinstance(inputs, str):
            messages.append({"from": "human", "value": inputs})
        elif isinstance(inputs, list):
            for msg in inputs:
                role = "human" if msg["role"] == "user" else "gpt"
                messages.append({"from": role, "value": msg["content"]})

        # Tokenize the conversation with the llama-3 chat template.
        tokenized_input = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to("cuda")

        # Generate the completion.
        with torch.no_grad():
            output = self.model.generate(
                input_ids=tokenized_input,
                max_new_tokens=max_tokens,
                do_sample=True,  # required for temperature/top_p to take effect
                temperature=temperature,
                top_p=top_p,
                use_cache=True,
            )

        # Decode the full sequence (prompt + completion) and return the last
        # non-empty line, i.e. the model's final reply.
        full_response = self.tokenizer.decode(output[0], skip_special_tokens=True)
        response_lines = [line.strip() for line in full_response.split("\n") if line.strip()]
        last_response = response_lines[-1] if response_lines else ""
        return [last_response]
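

# Minimal local smoke-test sketch, not part of the Inference Endpoints contract.
# It assumes this file sits next to the model weights ("." as the repository path)
# and that a CUDA GPU is available; the payload shape mirrors what the endpoint
# receives: an "inputs" string plus optional "parameters".
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "inputs": "What is the capital of France?",
        "parameters": {"max_new_tokens": 64, "temperature": 0.2, "top_p": 0.5},
    }
    print(handler(payload))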