from typing import Any, Dict

import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the model and tokenizer from the repository path passed in by
        # the inference toolkit (previously hardcoded to ".").
        logger.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        logger.info("Tokenizer loaded.")
        self.model = AutoModelForCausalLM.from_pretrained(path)
        logger.info("Model loaded.")

        # Causal LM tokenizers (e.g. GPT-2) often ship without a pad token;
        # padding=True below would fail without one, so fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data: JSON input with structure:
                {
                    "inputs": "your text prompt here",
                    "parameters": {
                        "max_new_tokens": 50,
                        "temperature": 0.7,
                        "top_p": 0.9,
                        "do_sample": true
                    }
                }

        Returns:
            A dict with "generated_text" (the first generation) and
            "all_generations" (every sequence, useful when
            num_return_sequences > 1).
        """
        # Get input text and parameters.
        inputs = data.pop("inputs", data)
        logger.info("inputs loaded: %s", inputs)
        parameters = data.pop("parameters", {})

        # Default generation parameters, overridable per request.
        generation_config = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "pad_token_id": self.tokenizer.pad_token_id,
            "num_return_sequences": parameters.get("num_return_sequences", 1),
        }

        # Tokenize; use a separate variable so the raw prompt is not shadowed.
        encoded = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)

        # Generate text without tracking gradients.
        with torch.no_grad():
            generated_ids = self.model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                **generation_config,
            )

        # Decode and return the generated text.
        generated_texts = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )

        return {
            "generated_text": generated_texts[0],  # First generation
            "all_generations": generated_texts,    # All sequences when num_return_sequences > 1
        }
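

# A minimal local smoke test for the handler above; an illustrative sketch,
# not part of the Inference Endpoints contract. In production the toolkit
# imports EndpointHandler from handler.py and invokes it per request. The
# model id "gpt2" is an assumption here, chosen only because it is a small
# public causal LM; any local directory or Hub repo id with a causal LM works.
if __name__ == "__main__":
    handler = EndpointHandler(path="gpt2")  # assumption: downloads gpt2 from the Hub
    payload = {
        "inputs": "Once upon a time",
        "parameters": {"max_new_tokens": 20, "temperature": 0.7, "do_sample": True},
    }
    result = handler(payload)
    print(result["generated_text"])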