from typing import Dict, List, Any

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        logger.info(f"Initializing EndpointHandler with model path: {path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            logger.info("Tokenizer loaded successfully")

            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                device_map="auto"
            )
            logger.info(f"Model loaded successfully. Device: {self.model.device}")

            self.model.eval()
            logger.info("Model set to evaluation mode")

            # Default generation parameters
            self.default_params = {
                "max_new_tokens": 1000,
                "temperature": 0.01,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True
            }
            logger.info(f"Default generation parameters: {self.default_params}")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle chat completion requests.

        Args:
            data: Dictionary containing:
                - inputs: The user message to respond to
                - generation_params: Optional dictionary of generation parameter overrides

        Returns:
            List containing a single dictionary with the generated text and
            generation details.
        """
        try:
            logger.info("Processing new request")
            logger.info(f"Input data: {data}")

            input_messages = data.get("inputs", [])
            if not input_messages:
                logger.warning("No input messages provided")
                return [{"role": "assistant", "content": "No input messages provided"}]

            # Get generation parameters, using defaults for missing values
            gen_params = {**self.default_params, **data.get("generation_params", {})}
            logger.info(f"Generation parameters: {gen_params}")

            # Apply the chat template
            messages = [{"role": "user", "content": input_messages}]
            logger.info("Applying chat template")
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            logger.info(f"Generated chat prompt: {json.dumps(prompt)}")

            # Tokenize the prompt
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # Generate the response
            logger.info("Generating response")
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **inputs,
                    **gen_params
                )

            # Decode the response (special tokens are kept so the role markers
            # can be located below)
            logger.info("Decoding response")
            output_text = self.tokenizer.batch_decode(output_tokens)[0]

            # Extract only the assistant's response by finding the last assistant role block
            assistant_start = output_text.rfind("<|start_of_role|>assistant<|end_of_role|>")
            if assistant_start != -1:
                response = output_text[assistant_start + len("<|start_of_role|>assistant<|end_of_role|>"):].strip()

                # Remove any trailing end_of_text marker
                if "<|end_of_text|>" in response:
                    response = response.split("<|end_of_text|>")[0].strip()

                # Check for function calling
                if "Calling function:" in response:
                    # Split the response into its text and function-call parts
                    parts = response.split("Calling function:", 1)
                    text_response = parts[0].strip()
                    function_call = "Calling function:" + parts[1].strip()

                    logger.info(f"Function call: {function_call}")
                    logger.info(f"Text response: {text_response}")

                    # Return only the text portion; the function call is logged above
                    return [
                        {
                            "generated_text": text_response,
                            "details": {
                                "finish_reason": "stop",
                                "generated_tokens": len(output_tokens[0])
                            }
                        }
                    ]
            else:
                # No assistant role marker found; fall back to the full decoded output
                response = output_text

            logger.info(f"Generated response: {json.dumps(response)}")
            return [{"generated_text": response,
                     "details": {"finish_reason": "stop", "generated_tokens": len(output_tokens[0])}}]
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            raise