import json
import os
from typing import Any, Dict, List, Union

import torch
from transformers import AutoTokenizer, AutoModel


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        num_added = self.tokenizer.add_special_tokens({
            "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
        })

        self.model = AutoModel.from_pretrained(path).to(self.device)
        if num_added > 0:
            # The special tokens were not in the saved vocabulary, so the embedding
            # matrix must grow to match or inference would index out of range.
            self.model.resize_token_embeddings(len(self.tokenizer))

        # Load the scoring head saved alongside the encoder.
        head_path = os.path.join(path, "classifier_head.json")
        with open(head_path, "r") as f:
            head = json.load(f)
        self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
        # Coerce to the shapes nn.Linear expects, regardless of how the JSON nests the values.
        self.classifier.weight.data = (
            torch.tensor(head["scorer_weight"], dtype=torch.float32).view(1, -1).to(self.device)
        )
        self.classifier.bias.data = (
            torch.tensor(head["scorer_bias"], dtype=torch.float32).view(1).to(self.device)
        )

        self.model.eval()

        # Batch processing configuration
        self.max_batch_size = 128  # Adjust based on GPU memory
        self.max_length = 64

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        payload = data.get("inputs", data)

        # A "queries" key means batch processing (multiple queries); otherwise
        # fall back to the original single-query path.
        if "queries" in payload:
            return self._process_batch(payload)
        return self._process_single(payload)

    def _process_single(self, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Original single-query processing, kept for backward compatibility."""
        query = payload["query"]
        candidates = payload["candidates"]

        results = []
        with torch.no_grad():
            for entry in candidates:
                text = (
                    f"[QUERY] {query} [LABEL_NAME] {entry['label']} "
                    f"[LABEL_DESCRIPTION] {entry['description']}"
                )
                tokens = self.tokenizer(
                    text,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                ).to(self.device)
                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]  # [CLS] embedding
                score = torch.sigmoid(self.classifier(cls)).item()
                results.append({
                    "label": entry["label"],
                    "description": entry["description"],
                    "score": round(score, 4),
                })

        return sorted(results, key=lambda x: x["score"], reverse=True)

    def _process_batch(self, payload: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """True batch processing for multiple queries."""
        queries = payload["queries"]
        candidates = payload["candidates"]

        # Build all query-candidate pairs. The flat ordering is
        # q_idx * len(candidates) + c_idx, which is used below to map
        # each score back to its query.
        all_texts = []
        for query in queries:
            for candidate in candidates:
                all_texts.append(
                    f"[QUERY] {query} [LABEL_NAME] {candidate['label']} "
                    f"[LABEL_DESCRIPTION] {candidate['description']}"
                )

        # Score in chunks of max_batch_size to avoid exhausting GPU memory.
        all_scores = []
        with torch.no_grad():
            for i in range(0, len(all_texts), self.max_batch_size):
                batch_texts = all_texts[i:i + self.max_batch_size]
                tokens = self.tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                ).to(self.device)

                # Single forward pass for the entire chunk.
                out = self.model(**tokens)
                cls = out.last_hidden_state[:, 0, :]
                # squeeze(-1) keeps the batch dimension even when the chunk
                # holds a single item, so no special-casing is needed.
                scores = torch.sigmoid(self.classifier(cls)).squeeze(-1)
                all_scores.extend(scores.cpu().tolist())

        # Reshape the flat score list back into one sorted result list per query.
        results = []
        for q_idx in range(len(queries)):
            query_results = []
            for c_idx, candidate in enumerate(candidates):
                score = all_scores[q_idx * len(candidates) + c_idx]
                query_results.append({
                    "label": candidate["label"],
                    "description": candidate["description"],
                    "score": round(score, 4),
                })
            query_results.sort(key=lambda x: x["score"], reverse=True)
            results.append(query_results)

        return results

    def get_batch_stats(self) -> Dict[str, Any]:
        """Return batch processing statistics."""
        return {
            "max_batch_size": self.max_batch_size,
            "max_length": self.max_length,
            "device": str(self.device),
            "model_name": getattr(self.model.config, "name_or_path", "unknown"),
        }
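

# --- Usage sketch (illustrative, not part of the deployed handler) ----------
# A minimal local smoke test showing the two payload shapes __call__ accepts.
# The checkpoint path "./model" and the label set below are hypothetical
# placeholders; substitute your own exported model directory.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")

    candidates = [
        {"label": "billing", "description": "Questions about invoices and payments"},
        {"label": "technical", "description": "Bug reports and crashes"},
    ]

    # Single-query payload -> one flat list, sorted by score (descending).
    single = handler({"inputs": {"query": "my invoice is wrong", "candidates": candidates}})
    print(single)

    # Batch payload -> one sorted list per query, in query order.
    batch = handler({
        "inputs": {
            "queries": ["my invoice is wrong", "the app crashes on login"],
            "candidates": candidates,
        }
    })
    print(batch)

    print(handler.get_batch_stats())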