|
import logging
from ast import literal_eval
from math import exp

import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, BertTokenizer, BertConfig

from . import label
|
|
|
|
|
logger = logging.getLogger(__name__)


class Model(object):
    """Wrapper that loads a Hugging Face checkpoint and runs inference on it.

    Supported task names are ``"emotion"`` and ``"sentiment"`` (sequence
    classification) and ``"ner"`` (token classification). Call
    :meth:`load_model` before :meth:`predict`.
    """

    def __init__(self) -> None:
        # Checkpoint used by load_model() when no model_name is passed.
        self.model_name = "indolem/indobert-base-uncased"
        self.tokenizer = None
        self.model = None
        # Only populated when load_model() is called with tasks="emotion".
        self.config = None
        # Cached token-classification pipeline. It is rebuilt whenever
        # self.model or self.tokenizer is replaced (see _get_recognizer).
        self._recognizer = None

    def load_model(self, model_name: str = None, tasks: str = None):
        """Load the tokenizer and model for ``tasks`` from ``model_name``.

        Args:
            model_name: Hub id or local path of the checkpoint. Falls back to
                ``self.model_name`` when ``None``.
            tasks: ``"emotion"`` loads BERT-specific classes (config,
                tokenizer, sequence classifier). ``"ner"`` loads an auto
                tokenizer plus a token classifier. Any other value loads an
                auto tokenizer plus a sequence classifier.
        """
        if model_name is None:
            model_name = self.model_name
        logger.info("Loading model %s (tasks=%r)", model_name, tasks)

        if tasks == "emotion":
            # Load the config first and pass it explicitly, as before.
            self.config = BertConfig.from_pretrained(model_name)
            self.tokenizer = BertTokenizer.from_pretrained(model_name)
            self.model = BertForSequenceClassification.from_pretrained(
                model_name, config=self.config
            )
            return

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tasks == "ner":
            self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

    def predict(self, sentences, tasks: str = None):
        """Run inference for ``tasks`` on ``sentences``.

        Args:
            sentences: Input text passed to the tokenizer (classification) or
                to the token-classification pipeline (NER). The classification
                path calls ``.item()`` on the argmax, so it handles exactly one
                sentence. NOTE(review): the NER path indexes each pipeline
                output as a dict, so it presumably expects a single string.
            tasks: ``"emotion"``, ``"sentiment"``, ``"ner"``, or anything else.

        Returns:
            For ``"emotion"``/``"sentiment"``: ``{"label": ..., "score": float}``.
            For ``"ner"``: a list of dicts with keys ``entity``, ``score``,
            ``index``, ``word``, ``start`` and ``end``.
            For any other task: ``""``. No inference is run.

        Raises:
            RuntimeError: If no model/tokenizer has been loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded; call load_model() first.")

        if tasks in ("emotion", "sentiment"):
            return self._classify(sentences, tasks)
        if tasks == "ner":
            return self._recognize_entities(sentences)
        # Unknown or unspecified task: there is nothing meaningful to compute.
        return ""

    def _classify(self, sentences, tasks: str) -> dict:
        """Return the top class label and its score for a single sentence."""
        encoded_input = self.tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        with torch.no_grad():
            logits = self.model(**encoded_input).logits

        # .item() requires a single row, so this path supports one sentence.
        predicted_class = torch.argmax(logits, dim=1).item()
        # Score is the sigmoid of the winning logit, not a softmax over classes.
        # NOTE(review): confirm this is intended; kept for backward compatibility.
        score = self._sigmoid(float(logits[0, predicted_class]))
        # `label` comes from the package; presumably label[task][class_index]
        # yields the class name. Verify against the label module.
        return {"label": label[tasks][predicted_class], "score": score}

    def _recognize_entities(self, sentences) -> list:
        """Run token classification and normalise each entity to plain Python types."""
        with torch.no_grad():
            entities = self._get_recognizer()(sentences)
        return [
            {
                "entity": entity["entity"],
                "score": float(entity["score"]),
                "index": int(entity["index"]),
                "word": entity["word"],
                "start": int(entity["start"]),
                "end": int(entity["end"]),
            }
            for entity in entities
        ]

    def _get_recognizer(self):
        """Return a token-classification pipeline bound to the current model/tokenizer."""
        recognizer = getattr(self, "_recognizer", None)
        # Rebuild when missing or when model/tokenizer were swapped since caching.
        if (
            recognizer is None
            or recognizer.model is not self.model
            or recognizer.tokenizer is not self.tokenizer
        ):
            recognizer = pipeline(
                "token-classification", model=self.model, tokenizer=self.tokenizer
            )
            self._recognizer = recognizer
        return recognizer

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Return the logistic sigmoid of ``x`` without overflowing for large ``|x|``."""
        if x >= 0:
            return 1.0 / (1.0 + exp(-x))
        # For negative x, exp(x) <= 1, so this form cannot overflow either.
        z = exp(x)
        return z / (1.0 + z)