ArxivTopicPicker / inference.py
Pavel Malov
CPU only
5eae5da
raw
history blame
2.69 kB
import json
import torch
from torch import nn
from typing import List, Dict, Set
from pathlib import Path
from transformers import DistilBertTokenizer, DistilBertModel
class Nnet(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> BatchNorm1d -> Linear.

    Args:
        in_features: Size of each input vector (default 768).
        hidden_features: Width of the hidden layer (default 256).
        num_classes: Number of output logits (default 85).

    The defaults reproduce the original hard-coded architecture.
    """

    def __init__(
        self,
        in_features: int = 768,
        hidden_features: int = 256,
        num_classes: int = 85,
    ) -> None:
        super().__init__()
        # The attribute name ``nnet`` and the layer order fix the state_dict keys
        # (``nnet.0.*``, ``nnet.2.*``, ``nnet.3.*``). Keep them stable so saved
        # checkpoints still load.
        self.nnet = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_features),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return raw logits for a batch ``x`` of shape ``(batch, in_features)``."""
        return self.nnet(x)
class ClassificationHead(nn.Module):
    """Wrap :class:`Nnet` and load its trained weights from a checkpoint file.

    Args:
        ckpt_path: Path of a ``torch.load``-able file whose ``'state_dict'``
            entry holds the :class:`Nnet` weights. The default is the original
            ``resources/model.ckpt``, resolved relative to the current working
            directory.
    """

    def __init__(self, ckpt_path: str = "resources/model.ckpt") -> None:
        super().__init__()
        self.nnet = Nnet()
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        # strict=False tolerates extra keys in the checkpoint. It would also
        # silently skip weights that are absent, leaving those layers randomly
        # initialised, so report that case instead of hiding it.
        result = self.nnet.load_state_dict(ckpt['state_dict'], strict=False)
        if result.missing_keys:
            import warnings

            warnings.warn(
                f"Checkpoint {ckpt_path!r} is missing weights for: "
                f"{', '.join(result.missing_keys)}; those layers keep their "
                "random initialisation.",
                RuntimeWarning,
                stacklevel=2,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a single, unbatched input vector.

        A leading batch dimension of size 1 is added before calling the MLP, so
        the returned logits have shape ``(1, num_classes)``.
        """
        return self.nnet(x.unsqueeze(0))
class InferenceModel:
    """Predict arXiv topic labels for a piece of text.

    Pipeline: DistilBERT tokenizer -> DistilBERT encoder -> token-averaged
    embedding -> :class:`ClassificationHead` -> softmax -> top-p selection ->
    newline-separated label string.
    """

    def __init__(self) -> None:
        # Both are obtained through ``from_pretrained('distilbert-base-uncased')``.
        self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head: nn.Module = ClassificationHead()
        tag_mapping = json.loads(Path('resources/tag_mapping.json').read_text(encoding='utf-8'))
        # Maps class index (model output position) to label string.
        self.mapping: Dict[int, str] = self._build_mapping(tag_mapping)

    @staticmethod
    def _build_mapping(tag_mapping: Dict) -> Dict[int, str]:
        """Build a deterministic ``class index -> label`` dict.

        The unique values of *tag_mapping*, minus the empty string, are sorted
        and numbered from 0.
        """
        labels: Set = set(tag_mapping.values())
        # discard (not remove): don't fail if there is no empty-string entry.
        labels.discard('')
        # Iterating a bare set of strings follows hash order, which Python
        # randomises per process. Enumerating it directly therefore gave a
        # different index->label assignment on every run. Sorting makes the
        # assignment reproducible.
        # NOTE(review): this order must match the label order used when the
        # classification head was trained — confirm against the training code.
        return dict(enumerate(sorted(labels)))

    def topp(self, probs: torch.Tensor, p: float = 0.95) -> torch.Tensor:
        """Select the most probable classes until their total probability exceeds *p*.

        Args:
            probs: 1-D tensor of non-negative class probabilities.
            p: Cumulative-probability threshold (default 0.95, as before).

        Returns:
            1-D tensor of class indices, most probable first. The class whose
            probability pushes the running total past *p* is included, so the
            result is non-empty whenever *probs* is non-empty.
        """
        sorted_probs, sorted_inds = torch.sort(probs, descending=True)
        accum = torch.cumsum(sorted_probs, dim=0)
        # For non-negative probs, accum is non-decreasing. The count of entries
        # <= p therefore equals the index of the first entry > p. The +1 keeps
        # that crossing entry: the original sliced it off and returned nothing
        # when the top class alone exceeded p.
        # min() caps the count when no entry exceeds p; the original raised
        # IndexError in that case.
        n_keep = min(int((accum <= p).sum().item()) + 1, sorted_inds.numel())
        return sorted_inds[:n_keep]

    def get_lables(self, classes: torch.Tensor) -> str:
        """Map class indices to labels; each label is followed by a newline."""
        return "".join(self.mapping[cls] + '\n' for cls in classes.tolist())

    def inference(self, x: str) -> str:
        """Return the predicted labels for text *x* as a newline-separated string."""
        self.bert.eval()
        self.head.eval()
        with torch.no_grad():
            encoded_input = self.tokenizer(x, return_tensors='pt', truncation=True)
            embeddings = self.bert(**encoded_input)
            # Element 0 of the encoder output is averaged over dim 1 and the
            # leading dim is dropped. This presumably means a
            # (batch=1, seq_len, hidden) last hidden state -> (hidden,) vector;
            # confirm against the DistilBertModel docs.
            mean_embedding = embeddings[0].mean(dim=1)[0]
            # The head adds a batch dim of 1; [0] removes it again.
            probs = self.head(mean_embedding).softmax(-1)[0]
            topp_classes = self.topp(probs)
        return self.get_lables(topp_classes)