|
from os import path |
|
import streamlit as st |
|
|
|
|
|
|
|
|
|
import tensorflow as tf |
|
import torch |
|
from torch import nn |
|
from transformers import BertModel, BertTokenizer |
|
|
|
|
|
# Compute device used for inference: the GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hugging Face checkpoint name for the BERT encoder.
MODEL_NAME = "bert-base-cased"
|
|
|
|
|
class SentimentClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The attribute names ``bert``, ``drop`` and ``out`` determine the keys of
    ``state_dict()``, so they must match any saved checkpoint.

    Args:
        n_classes: Number of logits produced by the linear head.
    """

    def __init__(self, n_classes):
        super().__init__()
        # Pretrained encoder named by the module-level MODEL_NAME constant.
        self.bert = BertModel.from_pretrained(MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        # The head's input width must equal the encoder's hidden size.
        encoder_width = self.bert.config.hidden_size
        self.out = nn.Linear(encoder_width, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class scores (no softmax is applied).

        Args:
            input_ids: Token ids as produced by the BERT tokenizer.
            attention_mask: Mask marking real tokens vs. padding.

        Returns:
            The output of the linear head applied to BERT's pooled output.
        """
        # With return_dict=False the encoder returns a tuple; the second element
        # is the pooled output that feeds the classification head.
        _sequence_output, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        regularized = self.drop(pooled)
        return self.out(regularized)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fine-tuned classifier weights stored next to this script. Despite the ".h5"
# extension, the file is read with torch.load as a PyTorch state dict.
MODEL_PATH = path.join(path.split(__file__)[0], "bert_model.h5")
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Build the classifier, load its fine-tuned weights, and load the tokenizer.

    Wrapped in ``st.cache_resource`` so the expensive load is cached rather
    than repeated on every call.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is in eval mode and has been
        moved to the module-level ``device``.
    """
    # Three output classes, matching the negative/neutral/positive labels
    # used by predict().
    model = SentimentClassifier(3)
    # Deserialize onto the CPU first so the checkpoint loads regardless of the
    # device it was saved from, then move the whole model to the inference
    # device. Inputs are placed on that same device, so there is no
    # CPU/GPU mismatch.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    # Use the same checkpoint name as the encoder so the vocabulary matches.
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer
|
|
|
|
|
def predict(content):
    """Classify the sentiment of a piece of text.

    Args:
        content: The raw text to classify.

    Returns:
        One of ``"negative"``, ``"neutral"`` or ``"positive"``.
    """
    model, tokenizer = load_model_and_tokenizer()
    # Send inputs to wherever the model's weights actually live. Choosing CUDA
    # independently of the model crashes with a device mismatch when the model
    # sits on the CPU.
    model_device = next(model.parameters()).device

    encoded_review = tokenizer.encode_plus(
        content,
        max_length=160,
        add_special_tokens=True,
        return_token_type_ids=False,
        # Replaces the deprecated ``pad_to_max_length=True``; truncation is made
        # explicit so long inputs are cut to max_length instead of relying on a
        # warned-about default.
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    input_ids = encoded_review["input_ids"].to(model_device)
    attention_mask = encoded_review["attention_mask"].to(model_device)

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    # Batch of one: take the index of the highest-scoring class as a Python int.
    _, prediction = torch.max(logits, dim=1)
    class_index = int(prediction.item())

    class_names = ["negative", "neutral", "positive"]
    return class_names[class_index]
|
|
|
|
|
def main():
    """Render the sentiment-detection page and show the prediction on demand."""
    st.title("Sentiment detection")
    contents = st.text_area("Please enter reviews/sentiment/sentences/contents:")

    # Only report a result after the user asks for one. Previously an empty
    # success box was rendered on every rerun.
    if st.button("Analyze Sentiment"):
        if not contents or not contents.strip():
            st.warning("Please enter some text to analyze.")
        else:
            prediction = predict(contents)
            st.success(prediction)
|
|
|
|
|
# Script entry point: build the page only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|