dejanseo committed
Commit c6e6058 (verified) · Parent(s): 65018a5

Update handler.py
Files changed (1):
1. handler.py (+90, -2)
handler.py CHANGED
@@ -22,10 +22,22 @@ class EndpointHandler:
         self.classifier.bias.data = torch.tensor(head["scorer_bias"]).to(self.device)
 
         self.model.eval()
+
+        # Batch processing configuration
+        self.max_batch_size = 128  # Adjust based on GPU memory
+        self.max_length = 64
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         payload = data.get("inputs", data)
-
+
+        # Check if this is batch processing (multiple queries) or single query
+        if "queries" in payload:
+            return self._process_batch(payload)
+        else:
+            return self._process_single(payload)
+
+    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Original single query processing for backward compatibility"""
         query = payload["query"]
         candidates = payload["candidates"]
         results = []
@@ -38,7 +50,7 @@ class EndpointHandler:
             return_tensors="pt",
             padding="max_length",
             truncation=True,
-            max_length=64
+            max_length=self.max_length
         ).to(self.device)
 
         out = self.model(**tokens)
@@ -51,3 +63,79 @@ class EndpointHandler:
             })
 
         return sorted(results, key=lambda x: x["score"], reverse=True)
+
+    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
+        """True batch processing for multiple queries"""
+        queries = payload["queries"]
+        candidates = payload["candidates"]
+
+        # Create all query-candidate combinations
+        all_texts = []
+        query_indices = []
+        candidate_indices = []
+
+        for q_idx, query in enumerate(queries):
+            for c_idx, candidate in enumerate(candidates):
+                text = f"[QUERY] {query} [LABEL_NAME] {candidate['label']} [LABEL_DESCRIPTION] {candidate['description']}"
+                all_texts.append(text)
+                query_indices.append(q_idx)
+                candidate_indices.append(c_idx)
+
+        # Process in batches to avoid memory issues
+        all_scores = []
+        total_combinations = len(all_texts)
+
+        with torch.no_grad():
+            for i in range(0, total_combinations, self.max_batch_size):
+                batch_texts = all_texts[i:i + self.max_batch_size]
+
+                # Tokenize batch
+                tokens = self.tokenizer(
+                    batch_texts,
+                    return_tensors="pt",
+                    padding="max_length",
+                    truncation=True,
+                    max_length=self.max_length
+                ).to(self.device)
+
+                # Single forward pass for entire batch
+                out = self.model(**tokens)
+                cls = out.last_hidden_state[:, 0, :]
+                scores = torch.sigmoid(self.classifier(cls)).squeeze()
+
+                # Handle single item case
+                if scores.dim() == 0:
+                    scores = scores.unsqueeze(0)
+
+                all_scores.extend(scores.cpu().tolist())
+
+        # Reshape results back to query structure
+        results = []
+        for q_idx in range(len(queries)):
+            query_results = []
+            for c_idx, candidate in enumerate(candidates):
+                # Find the score for this query-candidate combination
+                combination_idx = q_idx * len(candidates) + c_idx
+                score = all_scores[combination_idx]
+
+                query_results.append({
+                    "label": candidate["label"],
+                    "description": candidate["description"],
+                    "score": round(score, 4)
+                })
+
+            # Sort by score for this query
+            query_results.sort(key=lambda x: x["score"], reverse=True)
+            results.append(query_results)
+
+        return results
+
+    def get_batch_stats(self) -> Dict[str, Any]:
+        """Return batch processing statistics"""
+        return {
+            "max_batch_size": self.max_batch_size,
+            "max_length": self.max_length,
+            "device": str(self.device),
+            "model_name": self.model.config.name_or_path if hasattr(self.model.config, 'name_or_path') else "unknown"
+        }
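
For reference, a minimal usage sketch of the updated handler (not part of the commit): the payload shapes follow the diff above, while the instantiation and the example candidates are assumptions for illustration. Inference Endpoints normally constructs the handler from the repository path, so the constructor call here is a stand-in.

# Usage sketch -- assumed instantiation and example data, not from the commit
handler = EndpointHandler(path=".")  # path argument assumed

candidates = [
    {"label": "billing", "description": "Invoices, charges, and payments"},     # example data
    {"label": "support", "description": "Technical help and troubleshooting"},  # example data
]

# Single query (original behavior): one list of scored candidates,
# sorted by score in descending order.
single = handler({"inputs": {"query": "my invoice is wrong", "candidates": candidates}})

# Batch: the presence of a "queries" key routes __call__ to _process_batch,
# which returns one sorted candidate list per query, in query order.
batch = handler({"inputs": {
    "queries": ["my invoice is wrong", "the app crashes on startup"],
    "candidates": candidates
}})
assert len(batch) == 2 and len(batch[0]) == len(candidates)

Because all_texts is built with candidates in the inner loop, the flat score list maps back to (query, candidate) pairs as combination_idx = q_idx * len(candidates) + c_idx; with 2 queries and 3 candidates, for example, query 1 / candidate 2 sits at flat index 1 * 3 + 2 = 5.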