import logging

import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertModel

logging.basicConfig(level=logging.ERROR)

# Maximum sequence length (in tokens) for model inputs.
MAX_LEN = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model_path = './model_DB_1.pt'


@st.cache_resource
def _load_tokenizer_and_model():
    """Build the tokenizer and the fine-tuned classifier once per process.

    Streamlit re-executes this whole script on every widget interaction;
    caching the resources keeps the tokenizer and checkpoint from being
    reloaded on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on ``device`` and in
        eval mode.
    """
    # Truncation/padding are per-call options, so they are not passed here.
    tok = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', do_lower_case=True)
    # NOTE(review): DistilBERTClass is not defined or imported in this file;
    # it must be provided and match the architecture the checkpoint was
    # trained with, or load_state_dict will fail.
    model = DistilBERTClass()
    # torch.load unpickles the file, so only load checkpoints you trust.
    # Load onto CPU first, then move the populated model to the target device.
    model.load_state_dict(torch.load(loaded_model_path, map_location=torch.device('cpu')))
    model.to(device)
    # Inference only: disable dropout so predictions are deterministic.
    model.eval()
    return tok, model


tokenizer, model_DB = _load_tokenizer_and_model()
def sentiment_analysis_DB(input):
    """Classify the sentiment of one text with the fine-tuned DistilBERT model.

    Args:
        input: The raw text to classify. (The name is kept for backward
            compatibility even though it shadows the ``input`` builtin.)

    Returns:
        ``1`` if the sigmoid of the model output exceeds 0.5 (positive),
        otherwise ``0``.
    """
    encoded = tokenizer(
        input,
        add_special_tokens=True,
        max_length=MAX_LEN,
        # Replacement for the deprecated `pad_to_max_length=True`; explicit
        # truncation keeps long inputs within MAX_LEN tokens.
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_tensors='pt',  # already batched: shape (1, MAX_LEN)
    )
    # The model was moved to `device`, so its inputs must be there too;
    # leaving them on the CPU fails with a device mismatch on CUDA hosts.
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)
    token_type_ids = encoded['token_type_ids'].to(device)

    # Inference only: turn off dropout and skip autograd bookkeeping.
    model_DB.eval()
    with torch.no_grad():
        output = model_DB(ids, mask, token_type_ids)
    logging.debug('Raw output is %s', output)

    sigmoid_output = torch.sigmoid(output)
    logging.debug('Sigmoid output is %s', sigmoid_output)

    # NOTE(review): .item() assumes the model emits exactly one logit per
    # input (a single-unit binary head) — confirm against DistilBERTClass.
    return 1 if sigmoid_output.item() > 0.5 else 0
# Streamlit app
st.title("Sentiment Analysis App")

# User input
user_input = st.text_area("Enter some text:")

# Run the classifier only when the user clicks the button.
if st.button("Analyze Sentiment"):
    if not (user_input or "").strip():
        # An empty or whitespace-only input has no sentiment to classify;
        # sending it to the model would just yield an arbitrary label.
        st.warning("Please enter some text to analyze.")
    else:
        with st.spinner("Analyzing..."):
            result = sentiment_analysis_DB(user_input)
        # Display result
        if result == 1:
            st.success("Positive sentiment detected!")
        else:
            st.error("Negative sentiment detected.")