|
from transformers import BertTokenizer, BertForSequenceClassification,DistilBertTokenizer,DistilBertForSequenceClassification |
|
import torch |
|
import streamlit as st |
|
|
|
# Hugging Face Hub checkpoints for the two review-sentiment models.
BERT_CHECKPOINT = "ashish-001/Bert-Amazon-review-sentiment-classifier"
DISTILBERT_CHECKPOINT = "ashish-001/DistilBert-Amazon-review-sentiment-classifier"


@st.cache_resource
def _load_models():
    """Load both tokenizer/model pairs once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction
    (for example, each button click). Without caching, all four
    ``from_pretrained`` calls would run again on every rerun.
    ``st.cache_resource`` keeps the loaded objects alive and shares them
    across reruns and sessions. It requires Streamlit >= 1.18, where
    ``st.cache_resource`` was introduced.

    Returns:
        tuple: ``(bert_tokenizer, bert_model, distil_model, distil_tokenizer)``.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
    bert_model = BertForSequenceClassification.from_pretrained(BERT_CHECKPOINT)
    distil_clf = DistilBertForSequenceClassification.from_pretrained(
        DISTILBERT_CHECKPOINT)
    distil_tok = DistilBertTokenizer.from_pretrained(DISTILBERT_CHECKPOINT)
    return bert_tokenizer, bert_model, distil_clf, distil_tok


# Keep the original module-level names so the rest of the script (and any
# importers) keep working unchanged.
tokenizer, model, distil_model, distil_tokenizer = _load_models()
|
|
|
def classify_text(text, max_length=256):
    """Score *text* with the BERT review-sentiment model.

    Args:
        text: The review text to classify.
        max_length: Token budget for the tokenizer. Longer inputs are
            truncated and shorter ones are padded to this length
            (default 256, the original hard-coded value).

    Returns:
        A tensor of per-label scores produced by applying a sigmoid to the
        model's logits. The UI reads ``[0][0]`` as the negative score and
        ``[0][1]`` as the positive score (a batch of one review).
    """
    inputs = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # NOTE(review): sigmoid scores each logit independently, so the two
    # values need not sum to 1. If this checkpoint was fine-tuned with
    # cross-entropy (the usual single-label setup),
    # torch.softmax(logits, dim=-1) would give proper class probabilities.
    # Confirm against the model config's problem_type before changing.
    return torch.nn.functional.sigmoid(logits)
|
|
|
def classify_text_distilbert(text, max_length=512):
    """Score *text* with the DistilBERT review-sentiment model.

    Args:
        text: The review text to classify.
        max_length: Maximum number of tokens kept after truncation. The
            default of 512 matches the position-embedding size of standard
            DistilBERT checkpoints (TODO confirm against this model's
            config). Inputs within the limit are encoded exactly as before;
            longer ones previously failed at inference and are now
            truncated.

    Returns:
        A tensor of per-label scores produced by applying a sigmoid to the
        model's logits. The UI reads ``[0][0]`` as the negative score and
        ``[0][1]`` as the positive score (a batch of one review).
    """
    inputs = distil_tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = distil_model(**inputs).logits
    # NOTE(review): sigmoid does not yield values that sum to 1; if the
    # checkpoint is a single-label (cross-entropy) classifier, softmax over
    # the last dimension is the usual choice. Verify before changing.
    return torch.nn.functional.sigmoid(logits)
|
|
|
|
|
def _show_prediction(prediction):
    """Render the negative/positive scores for a one-review prediction.

    Column 0 is displayed as the negative score and column 1 as the
    positive score, for the first (and only) row of *prediction*.
    """
    st.header(
        f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")


st.title("Amazon Review Sentiment classifier")

data = st.text_area("Enter or paste a review")
# Only send real text to the models. An empty or whitespace-only box would
# otherwise be scored and produce meaningless confidences.
has_review = bool(data and data.strip())

if st.button('Predict using BERT'):
    if has_review:
        _show_prediction(classify_text(data))
    else:
        st.warning("Please enter a review before predicting.")

if st.button('Predict Using DistilBERT'):
    if has_review:
        _show_prediction(classify_text_distilbert(data))
    else:
        st.warning("Please enter a review before predicting.")
|
|
|
|