ArxivTopicPicker / inference.py
MaloYY's picture
Update inference.py
12d7b8a verified
import json
import torch
from torch import nn
from typing import List, Dict, Set
from pathlib import Path
from transformers import DistilBertTokenizer, DistilBertModel
class Nnet(nn.Module):
    """Small feed-forward network: 768 input features -> 256 hidden -> 85 outputs."""

    def __init__(self) -> None:
        super().__init__()
        # The position of each layer determines its state-dict key
        # (nnet.0.*, nnet.2.*, nnet.3.*), so the order must not change.
        layers = [
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 85),
        ]
        self.nnet = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return the result."""
        scores = self.nnet(x)
        return scores
class ClassificationHead(nn.Module):
    """``Nnet`` with weights restored from ``resources/model.ckpt``."""

    def __init__(self) -> None:
        super().__init__()
        self.nnet = Nnet()
        # Tensors in the checkpoint are mapped onto the CPU while loading.
        cpu = torch.device('cpu')
        checkpoint = torch.load("resources/model.ckpt", map_location=cpu)
        weights = checkpoint['state_dict']
        # strict=False: missing or unexpected keys are tolerated rather than
        # raising, so a key mismatch with ``Nnet`` goes unreported here.
        self.nnet.load_state_dict(weights, strict=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend a dimension of size 1 to ``x`` and run it through the network."""
        batched = x.unsqueeze(0)
        return self.nnet(batched)
class InferenceModel:
    """Turn input text into label strings.

    The pipeline is: DistilBERT tokenizer -> DistilBERT encoder -> mean of the
    encoder output over dim 1 -> ``ClassificationHead`` -> softmax -> top-p
    selection -> lookup in ``self.mapping``.
    """

    def __init__(self) -> None:
        """Load the tokenizer, encoder, classification head and label mapping."""
        self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head: nn.Module = ClassificationHead()

        raw_mapping = json.loads(
            Path('resources/tag_mapping.json').read_text(encoding='utf-8')
        )
        values: Set = set(raw_mapping.values())
        # The empty tag is not a real label; discard() tolerates its absence.
        values.discard('')
        # NOTE(review): iterating a set of strings is not order-stable across
        # interpreter runs (string hashing is randomized unless PYTHONHASHSEED
        # is fixed), so the index -> label assignment below can differ between
        # processes. Confirm against the ordering used when the head was
        # trained and persist/derive it deterministically.
        self.mapping: Dict = dict(enumerate(values))

    def topp(self, probs: torch.Tensor, p: float = 0.95) -> torch.Tensor:
        """Return indices of the most probable classes covering more than ``p`` mass.

        Args:
            probs: 1-D tensor of class probabilities.
            p: Cumulative-probability threshold.

        Returns:
            Class indices sorted by descending probability, up to and
            including the first class whose cumulative probability exceeds
            ``p``. If the threshold is never exceeded, all indices are
            returned.
        """
        sorted_probs, sorted_inds = torch.sort(probs, descending=True)
        accum = torch.cumsum(sorted_probs, dim=0)
        crossed = torch.nonzero(accum > p)
        if crossed.numel() == 0:
            return sorted_inds
        # +1 so the class that crosses the threshold is kept; without it a
        # single class holding more than ``p`` would yield an empty result.
        cutoff = int(crossed[0].item()) + 1
        return sorted_inds[:cutoff]

    def get_lables(self, classes: torch.Tensor) -> str:
        """Map class indices to labels, one per line.

        Args:
            classes: 1-D tensor of integer keys into ``self.mapping``.

        Returns:
            The labels concatenated in order, each followed by a newline
            (an empty string when ``classes`` is empty).
        """
        return "".join(self.mapping[idx] + '\n' for idx in classes.tolist())

    def inference(self, x: str) -> str:
        """Predict labels for the text ``x``.

        Returns:
            Newline-terminated labels of the classes selected by ``topp``.
        """
        self.bert.eval()
        self.head.eval()
        with torch.no_grad():
            # tokenize: str -> Tokens
            encoded_input = self.tokenizer(x, return_tensors='pt', truncation=True)
            # get embedding: Tokens -> Embeddings -> MeanEmbedding
            embeddings = self.bert(**encoded_input)
            mean_embedding = embeddings[0].mean(dim=1)[0]
            # get probs: MeanEmbedding -> Probs
            probs = self.head(mean_embedding).softmax(-1)[0]
            # get top_p classes: Probs -> 95% classes
            topp_classes = self.topp(probs)
            print(probs)
        # map classes to labels
        return self.get_lables(topp_classes)
# restart