import json
import logging
from pathlib import Path
from typing import Dict, List, Set

import torch
from torch import nn
from transformers import DistilBertModel, DistilBertTokenizer

logger = logging.getLogger(__name__)


class Nnet(nn.Module):
    """Two-layer MLP: 768 inputs -> 256 hidden (ReLU + BatchNorm) -> 85 outputs."""

    def __init__(self) -> None:
        super().__init__()
        self.nnet = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 85),
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the raw (pre-softmax) outputs."""
        return self.nnet(x)


class ClassificationHead(nn.Module):
    """``Nnet`` wrapper that loads its weights from ``resources/model.ckpt``.

    The checkpoint path is relative, so it is resolved against the current
    working directory of the process.
    """

    def __init__(self) -> None:
        super().__init__()
        self.nnet = Nnet()
        ckpt = torch.load("resources/model.ckpt", map_location=torch.device('cpu'))
        # strict=False tolerates key mismatches; surface them instead of
        # silently running with (partly) randomly initialised weights.
        load_result = self.nnet.load_state_dict(ckpt['state_dict'], strict=False)
        if load_result.missing_keys:
            logger.warning(
                "Checkpoint is missing parameters (left at random init): %s",
                load_result.missing_keys,
            )
        if load_result.unexpected_keys:
            logger.warning(
                "Checkpoint has parameters that were ignored: %s",
                load_result.unexpected_keys,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add a leading batch dimension to ``x`` and run it through the MLP."""
        return self.nnet(x.unsqueeze(0))


class InferenceModel:
    """Text -> labels pipeline: DistilBERT mean embedding -> head -> top-p labels."""

    def __init__(self) -> None:
        self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head: nn.Module = ClassificationHead()

        values: Set = set(json.loads(Path('resources/tag_mapping.json').read_text()).values())
        # Drop the empty tag if present (discard: no crash when it is absent).
        values.discard('')
        # NOTE(review): iteration order of a set of strings is not stable
        # across interpreter runs (hash randomisation), so this index -> label
        # assignment can differ between processes. The ordering used at
        # training time is not visible here -- confirm it and make this
        # deterministic (e.g. a sorted or explicitly stored class list).
        self.mapping: Dict = dict(enumerate(values))

    def topp(self, probs: torch.Tensor, p: float = 0.95):
        """Return the indices of the most probable classes covering mass ``p``.

        Classes are taken in order of decreasing probability up to and
        including the first one at which the cumulative probability exceeds
        ``p``, so at least one class is returned for any non-empty input.

        Args:
            probs: 1-D tensor of class probabilities.
            p: Cumulative probability threshold (defaults to 0.95).

        Returns:
            1-D tensor of class indices, most probable first.
        """
        sorted_probs, sorted_inds = torch.sort(probs, descending=True)
        accum = torch.cumsum(sorted_probs, dim=0)
        crossed = torch.nonzero(accum > p)
        if crossed.numel() == 0:
            # Threshold never exceeded (e.g. p >= total mass): keep everything.
            return sorted_inds
        # +1 so the class that crosses the threshold is part of the result;
        # without it a single class with probability > p yields no labels.
        cutoff = int(crossed[0]) + 1
        return sorted_inds[:cutoff]

    def get_lables(self, classes: torch.Tensor) -> str:
        """Map class indices to label names, one per line.

        Args:
            classes: 1-D tensor of integer class indices (keys of ``self.mapping``).

        Returns:
            A single string in which every label is followed by a newline;
            the empty string when ``classes`` is empty.
        """
        return "".join(self.mapping[cls] + '\n' for cls in classes.tolist())

    def inference(self, x: str) -> str:
        """Predict labels for the text ``x``.

        Args:
            x: Input text; it is truncated by the tokenizer if too long.

        Returns:
            Newline-terminated label names of the top-p classes, most
            probable first (see ``get_lables``).
        """
        self.bert.eval()
        self.head.eval()
        with torch.no_grad():
            # tokenize: str -> Tokens
            encoded_input = self.tokenizer(x, return_tensors='pt', truncation=True)
            # get embedding: Tokens -> Embeddings -> MeanEmbedding
            embeddings = self.bert(**encoded_input)
            mean_embedding = embeddings[0].mean(dim=1)[0]
            # get probs: MeanEmbedding -> Probs
            probs = self.head(mean_embedding).softmax(-1)[0]
            # get top_p classes: Probs -> 95% classes
            topp_classes = self.topp(probs)
        logger.debug("class probabilities: %s", probs)
        # map classes to labels
        return self.get_lables(topp_classes)  # restart