File size: 1,725 Bytes
36e0e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Dict, List, Any
from setfit import SetFitModel


class EndpointHandler:
    """Request handler that scores input texts with a SetFit classifier.

    The model is loaded once at construction time. Calling the instance with
    a request payload returns, for every input text, one score per label.
    """

    def __init__(self, path: str = ""):
        """Load the model and define the class-index-to-label mapping.

        Args:
            path: Location handed to ``SetFitModel.from_pretrained``.
        """
        self.model = SetFitModel.from_pretrained(path)
        # Position of a score in the model output -> human-readable label.
        self.id2label = {0: "Absent", 1: "Present"}

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Score every text in the request payload.

        Args:
            data: Request payload. The texts are read from ``data["inputs"]``;
                when that key is missing, ``data`` itself is used as the
                inputs (and is then rejected because it is not a list).

        Returns:
            One entry per input text, in input order. Each entry is a list of
            ``{"label": str, "score": float}`` dicts, one per score returned
            by the model for that text. An empty input list yields ``[]``.

        Raises:
            ValueError: If the inputs are not a list, or if any element of
                the list is not a string.
        """
        # Read without popping so the caller's payload is left unmodified.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, list) or not all(
            isinstance(text, str) for text in inputs
        ):
            raise ValueError("Input must be a list of strings")

        # Nothing to classify: skip the model call instead of handing it an
        # empty batch.
        if not inputs:
            return []

        # One row of scores per input text.
        all_scores = self.model.predict_proba(inputs)

        # ``.item()`` converts each score to a plain Python number so the
        # result can be serialized.
        return [
            [
                {"label": self.id2label[i], "score": score.item()}
                for i, score in enumerate(scores)
            ]
            for scores in all_scores
        ]