import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Any, Dict, List, Union


class EndpointHandler:
    def __init__(self, model_id: str):
        """
        Initializes the handler by loading the model and tokenizer.

        Args:
            model_id (str): The Hugging Face model ID
                (e.g., "MoritzLaurer/DeBERTa-v3-base-mnli").
                This is passed automatically by the Inference Endpoint infrastructure.
        """
        print(f"Loading model '{model_id}'...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)

        # Move model to the determined device
        self.model.to(self.device)
        # Set model to evaluation mode for consistent inference
        self.model.eval()
        print("Model and tokenizer loaded successfully.")

        # --- Determine label order ---
        # Preferred: dynamically read the labels from the model config.
        try:
            # Sort by ID to guarantee a consistent order even if the dict is unordered.
            sorted_labels = sorted(self.model.config.id2label.items())
            self.label_names = [label for _, label in sorted_labels]
            print(f"Using label names from model config: {self.label_names}")

            # Basic validation for the NLI task
            if len(self.label_names) != 3:
                print(
                    f"Warning: Expected 3 labels for NLI, but the model config has "
                    f"{len(self.label_names)}. Proceeding with the model's labels."
                )
            if (
                not any("entail" in l.lower() for l in self.label_names)
                or not any("neutral" in l.lower() for l in self.label_names)
                or not any("contra" in l.lower() for l in self.label_names)
            ):
                print(
                    f"Warning: Model labels {self.label_names} might not match the standard "
                    f"NLI labels ('entailment', 'neutral', 'contradiction')."
                )
        except AttributeError:
            # Fallback: use the conventional NLI label order if the config is missing or malformed.
            self.label_names = ["entailment", "neutral", "contradiction"]
            print(
                f"Warning: Could not read labels from model config. "
                f"Falling back to default: {self.label_names}"
            )
            print("Ensure this order matches the actual output order of the model!")

        print(f"Configured label order for output: {self.label_names}")

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Handles inference requests.

        Args:
            data (Dict[str, Any]): The input payload from the request.
                Expected keys: "premise" (str) and "hypothesis" (str),
                optionally nested under "inputs".

        Returns:
            A dictionary containing error info, or a list of dictionaries,
            each mapping a label name to its probability score.
        """
        # --- Input parsing ---
        inputs = data.get("inputs", data)  # Allow for optional "inputs" nesting
        if not isinstance(inputs, dict):
            return {"error": "Invalid payload. Expected a JSON object with 'premise' and 'hypothesis' keys."}

        premise = inputs.get("premise")
        hypothesis = inputs.get("hypothesis")

        # Basic input validation
        if not premise or not isinstance(premise, str):
            return {"error": "Missing or invalid 'premise' key in input. Expected a string."}
        if not hypothesis or not isinstance(hypothesis, str):
            return {"error": "Missing or invalid 'hypothesis' key in input. Expected a string."}
        # --- Tokenization ---
        # Tokenize the premise-hypothesis pair.
        try:
            tokenized_inputs = self.tokenizer(
                premise,
                hypothesis,
                return_tensors="pt",  # Return PyTorch tensors
                truncation=True,      # Truncate if longer than max length
                padding=True,         # Pad to the longest sequence in the batch (or max_length)
                max_length=self.tokenizer.model_max_length,  # Use the model's max length
            )
        except Exception as e:
            print(f"Error during tokenization: {e}")
            return {"error": f"Failed to tokenize input: {e}"}

        # Move tokenized inputs to the same device as the model
        tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}

        # --- Inference ---
        try:
            with torch.no_grad():  # Disable gradient calculations for efficiency
                outputs = self.model(**tokenized_inputs)
                logits = outputs.logits

            # Apply softmax to get probabilities
            probabilities = torch.softmax(logits, dim=-1)

            # Move probabilities to CPU and convert to a list.
            # Index [0] because a single premise/hypothesis pair is processed per request.
            scores = probabilities.cpu().numpy()[0].tolist()

            # --- Format output ---
            # Pair labels with their corresponding scores
            result = [
                {"label": label, "score": score}
                for label, score in zip(self.label_names, scores)
            ]
            return result

        except Exception as e:
            print(f"Error during model inference: {e}")
            # Consider logging the full traceback here in a real deployment:
            # import traceback; traceback.print_exc()
            return {"error": f"Model inference failed: {e}"}
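

# --- Optional local smoke test ---
# A minimal sketch for running the handler outside the Inference Endpoint, assuming the
# NLI checkpoint from the docstring; the sample premise/hypothesis pair is illustrative only.
if __name__ == "__main__":
    handler = EndpointHandler(model_id="MoritzLaurer/DeBERTa-v3-base-mnli")
    sample_request = {
        "inputs": {
            "premise": "A soccer game with multiple males playing.",
            "hypothesis": "Some men are playing a sport.",
        }
    }
    # Expected output: a list of {"label": ..., "score": ...} dicts, one per NLI label.
    print(handler(sample_request))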