Bashir Rastegarpanah commited on
Commit
9868e73
·
1 Parent(s): 19e7040

add custom handler

Browse files
Files changed (1) hide show
  1. handler.py +41 -0
handler.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer
3
+
4
+
5
+ class EndpointHandler:
6
+ def __init__(self,
7
+ model,
8
+ # path=""
9
+ ):
10
+ # Preload all the elements you are going to need at inference.
11
+ # pseudo:
12
+ # self.model= load_model(path)
13
+ self.model = model
14
+ self.tokenizer = AutoTokenizer.from_pretrained("roberta-large", padding_side=padding_side)
15
+
16
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
17
+
18
+ input_dict = data.pop("inputs", data)
19
+
20
+ self.model.eval()
21
+ input = self.tokenizer(input_dict['answer'],
22
+ input_dict['source'],
23
+ truncation=True,
24
+ max_length=None,
25
+ return_tensors="pt"
26
+ )
27
+ # input.to(device)
28
+ # with torch.no_grad():
29
+ # output = model(**input)
30
+ output = model(**input)
31
+
32
+ prediction = output.logits.argmax(dim=-1)
33
+
34
+ #smax = nn.Softmax(dim=1)
35
+ #score = smax(output.logits)
36
+
37
+ return [{
38
+ "label": prediction.item(),
39
+ #"score": score[0][0].item()
40
+ }
41
+ ]