Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from transformers import BertTokenizer, BertForSequenceClassification
|
2 |
import torch
|
3 |
import streamlit as st
|
4 |
|
@@ -7,6 +7,10 @@ tokenizer = BertTokenizer.from_pretrained(
|
|
7 |
model = BertForSequenceClassification.from_pretrained(
|
8 |
"ashish-001/Bert-Amazon-review-sentiment-classifier")
|
9 |
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def classify_text(text):
|
12 |
inputs = tokenizer(
|
@@ -21,10 +25,23 @@ def classify_text(text):
|
|
21 |
probs = torch.nn.functional.sigmoid(logits)
|
22 |
return probs
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
st.title("Amazon Review Sentiment classifier")
|
26 |
data = st.text_area("Enter or paste a review")
|
27 |
-
if st.button('Predict'):
|
28 |
prediction = classify_text(data)
|
29 |
st.header(
|
30 |
f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertTokenizer, BertForSequenceClassification,DistilBertTokenizer,DistilBertForSequenceClassification
|
2 |
import torch
|
3 |
import streamlit as st
|
4 |
|
|
|
7 |
model = BertForSequenceClassification.from_pretrained(
|
8 |
"ashish-001/Bert-Amazon-review-sentiment-classifier")
|
9 |
|
10 |
+
distil_model = DistilBertForSequenceClassification.from_pretrained(
|
11 |
+
"ashish-001/DistilBert-Amazon-review-sentiment-classifier")
|
12 |
+
distil_tokenizer = DistilBertTokenizer.from_pretrained(
|
13 |
+
"ashish-001/DistilBert-Amazon-review-sentiment-classifier")
|
14 |
|
15 |
def classify_text(text):
|
16 |
inputs = tokenizer(
|
|
|
25 |
probs = torch.nn.functional.sigmoid(logits)
|
26 |
return probs
|
27 |
|
28 |
+
def classify_text_distilbert(text):
|
29 |
+
inputs=distil_tokenizer(text, return_tensors="pt")
|
30 |
+
output = distil_model(**inputs)
|
31 |
+
logits = output.logits
|
32 |
+
probs = torch.nn.functional.sigmoid(logits)
|
33 |
+
return probs
|
34 |
+
|
35 |
|
36 |
st.title("Amazon Review Sentiment classifier")
|
37 |
data = st.text_area("Enter or paste a review")
|
38 |
+
if st.button('Predict using BERT'):
|
39 |
prediction = classify_text(data)
|
40 |
st.header(
|
41 |
f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")
|
42 |
+
|
43 |
+
if st.button('Predict Using DistilBERT'):
|
44 |
+
prediction = classify_text_distilbert(data)
|
45 |
+
st.header(
|
46 |
+
f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")
|
47 |
+
|