File size: 892 Bytes
0fabff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from typing import Dict, Any, List

from transformers import pipeline


class EndpointHandler:
    """Request handler that wraps a ``transformers`` pipeline.

    The pipeline is created once in ``__init__`` and reused for every
    request dispatched through ``__call__``.
    """

    def __init__(
            self,
            path: str,
    ) -> None:
        """Build the pipeline from the model found at ``path``.

        Args:
            path: Model location, forwarded as ``pipeline(model=path)``.
        """
        # NOTE(security): trust_remote_code=True allows transformers to run
        # Python code shipped with the model repository at ``path``. Only
        # point this handler at repositories you trust.
        self.pipeline = pipeline(
            model=path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            return_full_text=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the pipeline on one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is passed to the
                pipeline as its positional argument; if the key is absent,
                the payload itself (minus ``"parameters"``) is used instead.
                ``data["parameters"]``, when present and not ``None``, is
                unpacked into the pipeline call as keyword arguments.
                The caller's dict is not modified.

        Returns:
            The pipeline's result, returned without post-processing.
        """
        # Work on a shallow copy so popping keys does not mutate the
        # caller's request dict.
        payload = dict(data)
        parameters = payload.pop("parameters", None)
        # Without an "inputs" key, fall back to whatever remains of the
        # payload once "parameters" has been removed.
        inputs = payload.pop("inputs", payload)

        # A missing/None "parameters" means "no extra keyword arguments".
        if parameters is None:
            parameters = {}
        return self.pipeline(inputs, **parameters)