from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import logging
import time
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class EndpointHandler:
    def __init__(self, path: str = ""):
        logger.info(f"Initializing EndpointHandler with model path: {path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            logger.info("Tokenizer loaded successfully")

            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                device_map="auto"
            )
            logger.info(f"Model loaded successfully on device: {self.model.device}")
            self.model.eval()
            logger.info("Model set to evaluation mode")

            # Default generation parameters
            self.default_params = {
                "max_new_tokens": 1000,
                "temperature": 0.01,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True
            }
            logger.info(f"Default generation parameters: {self.default_params}")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle chat completion requests.

        Args:
            data: Dictionary containing:
                - inputs: The user message content to respond to
                - generation_params: Optional dictionary of generation parameter overrides

        Returns:
            List containing the generated response payload
        """
        try:
            logger.info("Processing new request")
            logger.info(f"Input data: {data}")

            input_messages = data.get("inputs", [])
            if not input_messages:
                logger.warning("No input messages provided")
                return [{"role": "assistant", "content": "No input messages provided"}]

            # Get generation parameters, use defaults for missing values
            gen_params = {**self.default_params, **data.get("generation_params", {})}
            logger.info(f"Generation parameters: {gen_params}")

            # Wrap the input as a single user message and apply the chat template
            messages = [{"role": "user", "content": input_messages}]
            logger.info("Applying chat template")
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            logger.info(f"Generated chat prompt: {json.dumps(prompt)}")

            # Tokenize the prompt
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # Generate response
            logger.info("Generating response")
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **inputs,
                    **gen_params
                )

            # Decode the response (special tokens are kept so the role markers can be located below)
            logger.info("Decoding response")
            output_text = self.tokenizer.batch_decode(output_tokens)[0]

            # Extract only the assistant's response by finding the last assistant role block
            assistant_marker = "<|start_of_role|>assistant<|end_of_role|>"
            assistant_start = output_text.rfind(assistant_marker)
            if assistant_start != -1:
                response = output_text[assistant_start + len(assistant_marker):].strip()
                # Remove any trailing end_of_text marker
                if "<|end_of_text|>" in response:
                    response = response.split("<|end_of_text|>")[0].strip()

                # Check for function calling
                if "Calling function:" in response:
                    # Split response into text and function call
                    parts = response.split("Calling function:", 1)
                    text_response = parts[0].strip()
                    function_call = "Calling function:" + parts[1].strip()
                    logger.info(f"Function call: {function_call}")
                    logger.info(f"Text response: {text_response}")
                    # Return both the text and the extracted function call
                    return [
                        {
                            "generated_text": text_response,
                            "tool_call": function_call,
                            "details": {
                                "finish_reason": "stop",
                                "generated_tokens": len(output_tokens[0])
                            }
                        }
                    ]
            else:
                # Assistant marker not found; fall back to the full decoded output
                response = output_text

            logger.info(f"Generated response: {json.dumps(response)}")
            return [{"generated_text": response, "details": {"finish_reason": "stop", "generated_tokens": len(output_tokens[0])}}]
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            raise
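

# Minimal local smoke test (a sketch, not part of the deployed handler).
# Assumes the model files are available at MODEL_PATH; on Hugging Face
# Inference Endpoints the serving toolkit instantiates EndpointHandler and
# calls it with the parsed request body instead.
if __name__ == "__main__":
    MODEL_PATH = "./model"  # hypothetical local path to the model files
    handler = EndpointHandler(path=MODEL_PATH)
    request = {
        "inputs": "What is the capital of France?",
        "generation_params": {"max_new_tokens": 64}
    }
    result = handler(request)
    print(json.dumps(result, indent=2))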