Tuyet3005's picture
Update streamlit_app.py/pages/SentimentDetection.py
febe13d
raw
history blame
2.87 kB
from os import path
import streamlit as st
# import pickle
# from tensorflow import keras
import tensorflow as tf
import torch
from torch import nn
from transformers import BertModel, BertTokenizer
# Name of the pretrained BERT checkpoint used for the encoder.
MODEL_NAME = "bert-base-cased"
# Compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the Sentiment Classifier class
class SentimentClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification layer.

    Args:
        n_classes: Number of output classes, i.e. the size of the score
            vector produced for each input.
        model_name: Name or path of the pretrained BERT checkpoint to load.
            ``None`` (the default) uses the module-level ``MODEL_NAME``.
        dropout: Dropout probability applied to the pooled encoder output.
    """

    # Constructor
    def __init__(self, n_classes, model_name=None, dropout=0.3):
        super().__init__()
        # Resolve the default here rather than in the signature so the class
        # definition itself does not depend on MODEL_NAME being defined yet.
        checkpoint = MODEL_NAME if model_name is None else model_name
        # Attribute names (bert / drop / out) are kept as-is: they form the
        # keys of the state dict that saved weights are loaded into.
        self.bert = BertModel.from_pretrained(checkpoint)
        self.drop = nn.Dropout(p=dropout)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    # Forward propagation
    def forward(self, input_ids, attention_mask):
        """Return the raw (un-normalised) class scores for a batch.

        Args:
            input_ids: Token-id tensor passed to the encoder.
            attention_mask: Attention-mask tensor passed to the encoder.

        Returns:
            The output of the linear layer applied to the dropped-out pooled
            encoder output.
        """
        # With return_dict=False the encoder returns a tuple; only its second
        # element (the pooled output) is used.
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # Dropout is only active in training mode; it is a no-op under eval().
        return self.out(self.drop(pooled_output))
# from keras_preprocessing.sequence import pad_sequences
# def predict(ham_spam):
# model = load_model(r'test_HSmodel_r.h5')
# with open('tokenizer.pickle','rb') as handle:
# tokenizer = pickle.load(handle)
# tokenizer.fit_on_texts(ham_spam)
# x_1 = tokenizer.texts_to_sequences([ham_spam])
# x_1 = pad_sequences(x_1, maxlen=525)
# predictions = model.predict(x_1)[0][0]
# return predictions
# Fine-tuned classifier weights, stored next to this file.
# NOTE(review): despite the ".h5" suffix the file is read with torch.load,
# so it is expected to be a PyTorch state dict — confirm.
MODEL_PATH = path.join(path.dirname(__file__), "bert_model.h5")


@st.cache_resource
def load_model_and_tokenizer():
    """Build the 3-class classifier, load its weights and its tokenizer.

    The function is wrapped in ``st.cache_resource`` so the (expensive)
    construction is handled by Streamlit's resource cache.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``SentimentClassifier`` in eval
        mode on the module-level ``device``, and the ``BertTokenizer`` for
        ``MODEL_NAME``.
    """
    model = SentimentClassifier(3)
    # Deserialise onto CPU first so a checkpoint saved on a GPU machine can
    # still be read on a CPU-only host.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Move the weights to the device the inputs are sent to; leaving the
    # model on CPU while inputs go to CUDA raises a device-mismatch error.
    model.to(device)
    model.eval()
    # Use the same checkpoint name as the encoder so the two cannot drift.
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer
def predict(content):
    """Classify the sentiment of a piece of text.

    Args:
        content: The text to classify.

    Returns:
        One of ``"negative"``, ``"neutral"`` or ``"positive"``.
    """
    model, tokenizer = load_model_and_tokenizer()
    # Send inputs to wherever the model's weights actually live instead of
    # re-deriving a device here; this avoids a CPU/CUDA mismatch.
    device = next(model.parameters()).device
    encoded_review = tokenizer.encode_plus(
        content,
        max_length=160,
        add_special_tokens=True,
        return_token_type_ids=False,
        # `padding="max_length"` replaces the deprecated
        # `pad_to_max_length=True`; truncation is requested explicitly so
        # inputs longer than max_length are cut without a warning.
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = encoded_review["input_ids"].to(device)
    attention_mask = encoded_review["attention_mask"].to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(input_ids, attention_mask)
    # Index of the highest score for the single input, as a plain int.
    prediction = torch.argmax(output, dim=1).item()
    # NOTE(review): label order must match the order used at training time —
    # confirm against the training code.
    class_names = ["negative", "neutral", "positive"]
    return class_names[prediction]
def main():
    """Render the sentiment-detection page.

    Shows a title, a text area and a button. When the button is clicked the
    entered text is classified with ``predict`` and the label is displayed;
    blank input produces a warning instead of a model call.
    """
    # giving a title to our page
    st.title("Sentiment detection")
    contents = st.text_area("Please enter reviews/sentiment/setences/contents:")
    # Create a prediction button. The label previously said "Spam Detection",
    # which did not match what this page does.
    if st.button("Analyze Sentiment Detection Result"):
        if not contents.strip():
            st.warning("Please enter some text to analyze.")
        else:
            # Only show a result box once there is a result to show.
            st.success(predict(contents))


if __name__ == "__main__":
    main()