File size: 2,702 Bytes
28f6ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5eae5da
28f6ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d7b8a
 
28f6ce1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
import torch
from torch import nn
from typing import List, Dict, Set
from pathlib import Path
from transformers import DistilBertTokenizer, DistilBertModel

class Nnet(nn.Module):
    """Two-layer perceptron mapping a 768-dim vector to 85 output scores.

    The stack is Linear(768, 256) -> ReLU -> BatchNorm1d(256) -> Linear(256, 85)
    and is stored under the attribute ``nnet``, so its parameters appear in the
    state dict as ``nnet.0.*``, ``nnet.2.*`` and ``nnet.3.*``.
    """

    def __init__(self) -> None:
        super().__init__()

        # Layers are listed in execution order; their position in the list
        # becomes the numeric index used in the state-dict keys.
        layers = [
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 85),
        ]
        self.nnet = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw scores."""
        scores = self.nnet(x)
        return scores
    

class ClassificationHead(nn.Module):
    """Wrap an ``Nnet`` whose weights are read from ``resources/model.ckpt``.

    The checkpoint is loaded onto the CPU and its ``'state_dict'`` entry is
    applied to the wrapped network.
    """

    def __init__(self) -> None:
        super().__init__()

        self.nnet = Nnet()

        cpu = torch.device('cpu')
        checkpoint = torch.load("resources/model.ckpt", map_location=cpu)
        weights = checkpoint['state_dict']
        # strict=False: keys that do not line up between the checkpoint and
        # the network are tolerated instead of raising.
        self.nnet.load_state_dict(weights, strict=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add a leading dimension to ``x`` and return the network output."""
        batched = x.unsqueeze(0)
        return self.nnet(batched)

class InferenceModel:
    """Predict tag labels for a text with DistilBERT and a classification head.

    Pipeline: tokenize -> DistilBERT -> mean over tokens -> head -> softmax ->
    top-p selection -> label names.
    """

    def __init__(self) -> None:
        self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head: nn.Module = ClassificationHead()

        raw_mapping = json.loads(Path('resources/tag_mapping.json').read_text())
        labels: Set[str] = set(raw_mapping.values())
        # discard() rather than remove(): a mapping file without an empty
        # label must not abort start-up with KeyError.
        labels.discard('')
        # Iterating a set of strings yields a different order in every process
        # (hash randomisation), which would attach labels to arbitrary logits.
        # Sorting makes the index -> label table reproducible.
        # NOTE(review): assumes the head was trained with labels in sorted
        # order -- confirm against the training code.
        self.mapping: Dict[int, str] = dict(enumerate(sorted(labels)))

    def topp(self, probs: torch.Tensor, p: float = 0.95) -> torch.Tensor:
        """Return the indices of the most probable classes covering mass ``p``.

        Classes are taken in descending probability until their cumulative
        probability exceeds ``p``; the class that crosses the threshold is
        included, so a non-empty ``probs`` always yields at least one index.

        Args:
            probs: 1-D tensor of class probabilities.
            p: cumulative probability threshold (default 0.95).

        Returns:
            1-D tensor of class indices, most probable first. If the
            threshold is never exceeded, every index is returned.
        """
        sorted_probs, sorted_inds = torch.sort(probs, descending=True)
        accum = torch.cumsum(sorted_probs, dim=0)
        # The count of prefix sums still <= p equals the position of the first
        # one above p; +1 keeps that crossing class. Slicing clamps the bound
        # when no prefix sum exceeds p.
        keep = int((accum <= p).sum().item()) + 1
        return sorted_inds[:keep]

    def get_lables(self, classes: torch.Tensor) -> str:
        """Return the label names for ``classes``, one per line.

        Args:
            classes: 1-D tensor of integer class indices (keys of
                ``self.mapping``).

        Returns:
            A single string in which every label is followed by a newline;
            an empty string when ``classes`` is empty.
        """
        return "".join(self.mapping[idx] + '\n' for idx in classes.tolist())

    def inference(self, x: str) -> str:
        """Predict labels for the text ``x``.

        Args:
            x: input text.

        Returns:
            Newline-terminated label names selected by ``topp``, most
            probable first.
        """
        self.bert.eval()
        self.head.eval()
        with torch.no_grad():
            # tokenize: str -> Tokens
            encoded_input = self.tokenizer(x, return_tensors='pt', truncation=True)
            # get embedding: Tokens -> Embeddings -> MeanEmbedding
            embeddings = self.bert(**encoded_input)
            # First model output averaged over dim 1, then the single batch
            # entry is taken (presumably (1, seq, 768) -> (768,); confirm).
            mean_embedding = embeddings[0].mean(dim=1)[0]
            # get probs: MeanEmbedding -> Probs
            probs = self.head(mean_embedding).softmax(-1)[0]

        # get top_p classes: Probs -> 95% classes
        topp_classes = self.topp(probs)
        # The full distribution is still echoed to stdout, as before.
        print(probs)
        # map classes to labels
        return self.get_lables(topp_classes)

# restart