startificial committed on
Commit c25d9aa · verified · 1 Parent(s): b758e43

Upload 2 files

Files changed (2)
  1. handler.py +119 -0
  2. requirements.txt +4 -0
handler.py ADDED
@@ -0,0 +1,119 @@
+ from typing import Any, Dict, List, Union
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ class EndpointHandler():
+     def __init__(self, model_id: str):
+         """
+         Initializes the handler by loading the model and tokenizer.
+
+         Args:
+             model_id (str): The Hugging Face model ID (e.g., "MoritzLaurer/DeBERTa-v3-base-mnli").
+                 This is automatically passed by the Inference Endpoint infrastructure.
+         """
+         print(f"Loading model '{model_id}'...")
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {self.device}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
+
+         # Move model to the determined device
+         self.model.to(self.device)
+         # Set model to evaluation mode for consistent inference
+         self.model.eval()
+         print("Model and tokenizer loaded successfully.")
+
+         # --- Determine Label Order ---
+         # Preferred: dynamically get the labels from the model config
+         try:
+             # Sort by ID to ensure a consistent order if the dict isn't ordered
+             sorted_labels = sorted(self.model.config.id2label.items())
+             self.label_names = [label for _, label in sorted_labels]
+             print(f"Using label names from model config: {self.label_names}")
+             # Basic validation for the NLI task
+             if len(self.label_names) != 3:
+                 print(f"Warning: Expected 3 labels for NLI, but model config has {len(self.label_names)}. Proceeding with the model's labels.")
+             if not any("entail" in l.lower() for l in self.label_names) or \
+                not any("neutral" in l.lower() for l in self.label_names) or \
+                not any("contra" in l.lower() for l in self.label_names):
+                 print(f"Warning: Model labels {self.label_names} might not match standard NLI labels ('entailment', 'neutral', 'contradiction').")
+
+         except AttributeError:
+             # Fallback: use the explicitly requested labels if the config is missing/malformed
+             self.label_names = ["entailment", "neutral", "contradiction"]
+             print(f"Warning: Could not read labels from model config. Falling back to default: {self.label_names}")
+             print("Ensure this order matches the actual output order of the model!")
+
+         print(f"Configured label order for output: {self.label_names}")
+
+
+     def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+         """
+         Handles inference requests.
+
+         Args:
+             data (Dict[str, Any]): The input data payload from the request.
+                 Expected keys: "premise" (str) and "hypothesis" (str).
+                 They can optionally be nested under "inputs".
+
+         Returns:
+             Union[Dict[str, Any], List[Dict[str, Any]]]: A dictionary containing error info,
+                 or a list of dictionaries, each mapping a label name to its probability score.
+         """
+         # --- Input Parsing ---
+         inputs = data.get("inputs", data)  # Allow for optional "inputs" nesting
+         premise = inputs.get("premise")
+         hypothesis = inputs.get("hypothesis")
+
+         # Basic input validation
+         if not premise or not isinstance(premise, str):
+             return {"error": "Missing or invalid 'premise' key in input. Expected a string."}
+         if not hypothesis or not isinstance(hypothesis, str):
+             return {"error": "Missing or invalid 'hypothesis' key in input. Expected a string."}
+
+         # --- Tokenization ---
+         # Tokenize the premise-hypothesis pair
+         try:
+             tokenized_inputs = self.tokenizer(
+                 premise,
+                 hypothesis,
+                 return_tensors="pt",  # Return PyTorch tensors
+                 truncation=True,      # Truncate if longer than max length
+                 padding=True,         # Pad to the longest sequence in the batch (or max_length)
+                 max_length=self.tokenizer.model_max_length  # Use the model's max length
+             )
+         except Exception as e:
+             print(f"Error during tokenization: {e}")
+             return {"error": f"Failed to tokenize input: {e}"}
+
+         # Move tokenized inputs to the same device as the model
+         tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}
+
+         # --- Inference ---
+         try:
+             with torch.no_grad():  # Disable gradient calculations for efficiency
+                 outputs = self.model(**tokenized_inputs)
+                 logits = outputs.logits
+
+             # Apply softmax to get probabilities
+             probabilities = torch.softmax(logits, dim=-1)
+
+             # Move probabilities to the CPU and convert to a list
+             # Index [0] because a single premise-hypothesis pair is processed (typical for endpoints)
+             scores = probabilities.cpu().numpy()[0].tolist()
+
+             # --- Format Output ---
+             # Pair labels with their corresponding scores
+             result = [{"label": label, "score": score} for label, score in zip(self.label_names, scores)]
+
+             return result
+
+         except Exception as e:
+             print(f"Error during model inference: {e}")
+             # Consider logging the full traceback here in a real deployment
+             # import traceback
+             # traceback.print_exc()
+             return {"error": f"Model inference failed: {e}"}
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers>=4.20.0  # Use a recent version
+ torch>=1.9.0          # Compatible Torch version
+ sentencepiece         # Often required by tokenizers
+ protobuf              # Sometimes needed as a dependency