File size: 1,328 Bytes
3625af8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
import torch



class EndpointHandler:
    """Custom inference handler that scores generated texts with a sequence classifier.

    Each text in the request is tokenized, passed through the classifier loaded
    from ``path``, and the softmax probability of label index 0 is returned.
    """

    def __init__(self, path=""):
        """Load the classifier config, model weights, and tokenizer from ``path``.

        Args:
            path: Local directory or model identifier accepted by the
                ``transformers`` ``from_pretrained`` loaders.
        """
        guider_config = AutoConfig.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path, config=guider_config)
        # from_pretrained already returns the model in eval mode; calling it
        # explicitly documents that dropout must stay disabled for scoring.
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score a batch of generated texts.

        Args:
            data: Request payload. The texts to score are read (and removed)
                from the ``"gen_outputs_no_input_decoded"`` key, given either
                as a list of strings or as a single string.

        Returns:
            ``{"guider_predictions": [...]}`` with one float per input text:
            the softmax probability the classifier assigns to label index 0.
            An empty input list yields an empty prediction list.

        Raises:
            ValueError: if ``"gen_outputs_no_input_decoded"`` is missing.
        """
        texts = data.pop("gen_outputs_no_input_decoded", None)
        if texts is None:
            # Falling back to `data` itself would make the tokenizer consume the
            # payload's *keys* and silently return meaningless scores.
            raise ValueError(
                "payload must contain 'gen_outputs_no_input_decoded' "
                "(a string or a list of strings)"
            )
        if isinstance(texts, str):
            # Iterating a bare string would score each character separately.
            texts = [texts]
        else:
            texts = list(texts)
        if not texts:
            return {"guider_predictions": []}

        guider_inputs = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )

        # Pure inference: skip building the autograd graph to save memory and time.
        with torch.no_grad():
            guider_outputs = self.model(**guider_inputs)

        probabilities = torch.nn.functional.softmax(guider_outputs.logits, dim=-1)
        # Column 0 is kept. The original author's note says index 0 = negative,
        # 1 = neutral, 2 = positive for this checkpoint.
        # NOTE(review): confirm against self.model.config.id2label.
        guider_predictions = probabilities[:, 0].tolist()
        return {"guider_predictions": guider_predictions}