import json
import torch
from torch import nn
from typing import List, Dict, Set
from pathlib import Path
from transformers import DistilBertTokenizer, DistilBertModel


class Nnet(nn.Module):
    """Trainable head: 768-dim DistilBERT embedding -> logits over 85 tags."""

    def __init__(self) -> None:
        super().__init__()
        self.nnet = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 85),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.nnet(x)


class ClassificationHead(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.nnet = Nnet()
        # Load the fine-tuned weights; strict=False tolerates extra or
        # prefixed keys left over in the training checkpoint.
        ckpt = torch.load("resources/model.ckpt", map_location=torch.device('cpu'))
        self.nnet.load_state_dict(ckpt['state_dict'], strict=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add a batch dimension: BatchNorm1d expects (batch, features).
        return self.nnet(x.unsqueeze(0))


class InferenceModel:
    def __init__(self) -> None:
        self.tokenizer: DistilBertTokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.bert: DistilBertModel = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head: nn.Module = ClassificationHead()
        # Build the index -> tag mapping from the saved tag set, dropping the
        # empty tag. Note: set iteration order is not stable across interpreter
        # runs, so this mapping must match the one used at training time.
        values: Set[str] = set(json.loads(Path('resources/tag_mapping.json').read_text()).values())
        values.discard('')
        self.mapping: Dict[int, str] = dict(enumerate(values))

    def topp(self, probs: torch.Tensor) -> torch.Tensor:
        """Return the smallest set of class indices whose probabilities sum to >= 0.95 (top-p selection)."""
        # Sort probabilities in descending order.
        sorted_probs, sorted_inds = torch.sort(probs, descending=True)
        # Accumulate probabilities.
        accum = torch.cumsum(sorted_probs, dim=0)
        # Index of the first element where the cumulative sum exceeds 0.95;
        # include that element, otherwise the result is empty whenever the
        # top class alone carries more than 0.95 of the probability mass.
        ind = int(torch.nonzero(accum > 0.95)[0])
        return sorted_inds[:ind + 1]
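    # Illustrative example (not executed): for probs = (0.60, 0.30, 0.08, 0.02),
    # the sorted cumulative sums are (0.60, 0.90, 0.98, 1.00); the first entry
    # above 0.95 is at index 2, so the three most likely classes are kept.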

    def get_labels(self, classes: torch.Tensor) -> List[str]:
        # Map class indices to human-readable tag labels, returning the
        # List[str] the annotation promises (the original built a newline-joined string).
        return [self.mapping[cls] for cls in classes.tolist()]

    def inference(self, x: str) -> List[str]:
        self.bert.eval()
        self.head.eval()
        with torch.no_grad():
            # Tokenize: str -> tokens.
            encoded_input = self.tokenizer(x, return_tensors='pt', truncation=True)
            # Embed: tokens -> token embeddings -> mean embedding.
            embeddings = self.bert(**encoded_input)
            mean_embedding = embeddings[0].mean(dim=1)[0]
            # Classify: mean embedding -> class probabilities.
            probs = self.head(mean_embedding).softmax(-1)[0]
            # Keep the smallest set of classes covering 95% of the probability mass.
            top_p_classes = self.topp(probs)
            # Map class indices to tag labels.
            return self.get_labels(top_p_classes)
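

# Usage sketch (illustrative; assumes resources/model.ckpt and
# resources/tag_mapping.json exist as loaded above -- the example question
# is made up, not taken from the original app):
if __name__ == "__main__":
    model = InferenceModel()
    tags = model.inference("How do I merge two branches in git?")
    print("\n".join(tags))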