import json
from pathlib import Path
from typing import Dict, Set

import torch
from torch import nn
from transformers import DistilBertModel, DistilBertTokenizer


class Nnet(nn.Module):
    """MLP head mapping a 768-dim DistilBERT embedding to 85 class logits."""

    def __init__(self) -> None:
        super().__init__()
        self.nnet = nn.Sequential(
            nn.Linear(768, 256),  # 768 = DistilBERT hidden size
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 85),  # 85 target classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.nnet(x)
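
# Shape sketch: Nnet()(torch.randn(4, 768)) yields logits of shape (4, 85).
# BatchNorm1d needs a batch dimension, which ClassificationHead.forward supplies below.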


class ClassificationHead(nn.Module):
    """Wraps Nnet and loads trained weights from a checkpoint with a 'state_dict' entry."""

    def __init__(self) -> None:
        super().__init__()
        self.nnet = Nnet()
        ckpt = torch.load("resources/model.ckpt", map_location=torch.device("cpu"))
        # strict=False tolerates checkpoint keys that do not belong to this module
        self.nnet.load_state_dict(ckpt["state_dict"], strict=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # add a batch dimension: BatchNorm1d expects (batch, features)
        return self.nnet(x.unsqueeze(0))


class InferenceModel:
    def __init__(self) -> None:
        self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head: nn.Module = ClassificationHead()
        values: Set[str] = set(json.loads(Path("resources/tag_mapping.json").read_text()).values())
        values.discard("")  # drop the empty tag; discard does not raise when it is absent
        # note: set iteration order is not stable across interpreter runs, so this
        # index -> label mapping must match the ordering used when the head was trained
        self.mapping: Dict[int, str] = dict(enumerate(values))
    def topp(self, probs: torch.Tensor) -> torch.Tensor:
        # sort probabilities in descending order
        sorted_probs, sorted_inds = torch.sort(probs, descending=True)
        # accumulate probabilities
        accum = torch.cumsum(sorted_probs, dim=0)
        # index of the first element at which the cumulative sum exceeds 0.95;
        # include that element so the returned classes cover at least 95% of the mass
        ind = torch.nonzero(accum > 0.95)[0].item()
        return sorted_inds[: ind + 1]
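
    # Worked example: probs = (0.60, 0.30, 0.08, 0.02) gives cumulative sums
    # (0.60, 0.90, 0.98, 1.00); the first sum above 0.95 sits at index 2, so the
    # top three classes are returned, together covering at least 95% of the mass.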

    def get_labels(self, classes: torch.Tensor) -> str:
        # map each predicted class index to its tag, one tag per line
        return "\n".join(self.mapping[cls] for cls in classes.tolist())

    def inference(self, x: str) -> str:
        self.bert.eval()
        self.head.eval()
        with torch.no_grad():
            # tokenize: str -> Tokens
            encoded_input = self.tokenizer(x, return_tensors="pt", truncation=True)
            # embed: Tokens -> Embeddings -> MeanEmbedding (mean over the token axis)
            embeddings = self.bert(**encoded_input)
            mean_embedding = embeddings.last_hidden_state.mean(dim=1)[0]
            # classify: MeanEmbedding -> Probs
            probs = self.head(mean_embedding).softmax(-1)[0]
            # nucleus cut: Probs -> smallest set of classes covering 95% of the mass
            topp_classes = self.topp(probs)
            # map class indices to labels
            return self.get_labels(topp_classes)
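

# A minimal usage sketch (assumed entry point, not part of the deployed Space);
# it presumes resources/model.ckpt and resources/tag_mapping.json are present.
if __name__ == "__main__":
    model = InferenceModel()
    # hypothetical input; any free-form text works
    print(model.inference("How do I merge two dictionaries in Python?"))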