File size: 1,653 Bytes
0c23025 dab0a62 0c23025 dab0a62 0c23025 dab0a62 0c23025 dab0a62 b7cb882 0c23025 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from transformers import BertTokenizer, BertForSequenceClassification,DistilBertTokenizer,DistilBertForSequenceClassification
import torch
import streamlit as st
# Hugging Face Hub checkpoints for the two fine-tuned review classifiers.
BERT_CHECKPOINT = "ashish-001/Bert-Amazon-review-sentiment-classifier"
DISTILBERT_CHECKPOINT = "ashish-001/DistilBert-Amazon-review-sentiment-classifier"


def _cache_resource(func):
    """Memoize *func*'s return value across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching both models would be reloaded from disk on each click.
    ``st.cache_resource`` is the current API for shared, unhashable objects
    such as models; older Streamlit releases only provide ``st.cache``, which
    needs ``allow_output_mutation=True`` to hand back the same object.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _load_bert():
    """Load and return ``(tokenizer, model)`` for the BERT checkpoint."""
    return (
        BertTokenizer.from_pretrained(BERT_CHECKPOINT),
        BertForSequenceClassification.from_pretrained(BERT_CHECKPOINT),
    )


@_cache_resource
def _load_distilbert():
    """Load and return ``(tokenizer, model)`` for the DistilBERT checkpoint."""
    return (
        DistilBertTokenizer.from_pretrained(DISTILBERT_CHECKPOINT),
        DistilBertForSequenceClassification.from_pretrained(DISTILBERT_CHECKPOINT),
    )


# Module-level names read by the classify_* functions.
tokenizer, model = _load_bert()
distil_tokenizer, distil_model = _load_distilbert()
def classify_text(text):
    """Score *text* with the fine-tuned BERT classifier.

    The text is tokenized to exactly 256 tokens: longer input is truncated
    and shorter input is padded. It is then passed through the module-level
    ``model``.

    Args:
        text: Review text to classify.

    Returns:
        Tensor holding ``torch.sigmoid`` of the model's logits. The UI reads
        ``[0][0]`` as the negative score and ``[0][1]`` as the positive score.
    """
    inputs = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    # Inference only: do not build an autograd graph or keep activations
    # for a backward pass that never happens.
    with torch.no_grad():
        logits = model(**inputs).logits
    # NOTE(review): sigmoid scores each logit independently, so the two values
    # need not sum to 1. If the checkpoint was trained with cross-entropy,
    # softmax would be the matching normalization. Confirm before changing.
    return torch.sigmoid(logits)
def classify_text_distilbert(text):
    """Score *text* with the fine-tuned DistilBERT classifier.

    Args:
        text: Review text to classify.

    Returns:
        Tensor holding ``torch.sigmoid`` of the model's logits. The UI reads
        ``[0][0]`` as the negative score and ``[0][1]`` as the positive score.
    """
    inputs = distil_tokenizer(
        text,
        # Without truncation, a review longer than the model's position
        # embedding table makes the forward pass fail instead of being cut.
        truncation=True,
        max_length=distil_model.config.max_position_embeddings,
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = distil_model(**inputs).logits
    # NOTE(review): as in the BERT path, sigmoid (not softmax) is applied, so
    # the two scores need not sum to 1. Confirm this matches training.
    return torch.sigmoid(logits)
def _show_prediction(prediction):
    """Render the negative/positive scores returned by a classify_* function."""
    st.header(
        f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")


st.title("Amazon Review Sentiment classifier")
data = st.text_area("Enter or paste a review")
# Whitespace-only input would still be tokenized and scored, producing
# confidences that mean nothing, so ask the user for text instead.
has_text = bool(data.strip())
if st.button('Predict using BERT'):
    if has_text:
        _show_prediction(classify_text(data))
    else:
        st.warning("Please enter a review before predicting.")
if st.button('Predict Using DistilBERT'):
    if has_text:
        _show_prediction(classify_text_distilbert(data))
    else:
        st.warning("Please enter a review before predicting.")
|