dejanseo commited on
Commit
65018a5
·
verified ·
1 Parent(s): 6f1c63c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -15
handler.py CHANGED
@@ -12,30 +12,22 @@ class EndpointHandler:
12
  "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
13
  })
14
  self.model = AutoModel.from_pretrained(path).to(self.device)
15
-
16
  head_path = os.path.join(path, "classifier_head.json")
17
  with open(head_path, "r") as f:
18
  head = json.load(f)
19
-
20
  self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
21
  self.classifier.weight.data = torch.tensor(head["scorer_weight"]).to(self.device)
22
  self.classifier.bias.data = torch.tensor(head["scorer_bias"]).to(self.device)
23
-
24
  self.model.eval()
25
 
26
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
- """
28
- Expected input format:
29
- {
30
- "query": "how to sharpen kitchen knives",
31
- "candidates": [
32
- {"label": "Tool-Specific", "description": "..."},
33
- {"label": "Local Intent", "description": "..."}
34
- ]
35
- }
36
- """
37
- query = data["query"]
38
- candidates = data["candidates"]
39
  results = []
40
 
41
  with torch.no_grad():
 
12
  "additional_special_tokens": ["[QUERY]", "[LABEL_NAME]", "[LABEL_DESCRIPTION]"]
13
  })
14
  self.model = AutoModel.from_pretrained(path).to(self.device)
15
+
16
  head_path = os.path.join(path, "classifier_head.json")
17
  with open(head_path, "r") as f:
18
  head = json.load(f)
19
+
20
  self.classifier = torch.nn.Linear(self.model.config.hidden_size, 1).to(self.device)
21
  self.classifier.weight.data = torch.tensor(head["scorer_weight"]).to(self.device)
22
  self.classifier.bias.data = torch.tensor(head["scorer_bias"]).to(self.device)
23
+
24
  self.model.eval()
25
 
26
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
+ payload = data.get("inputs", data)
28
+
29
+ query = payload["query"]
30
+ candidates = payload["candidates"]
 
 
 
 
 
 
 
 
31
  results = []
32
 
33
  with torch.no_grad():