ashish-001 commited on
Commit
0c23025
·
verified ·
1 Parent(s): b7cb882

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import BertTokenizer, BertForSequenceClassification
2
  import torch
3
  import streamlit as st
4
 
@@ -7,6 +7,10 @@ tokenizer = BertTokenizer.from_pretrained(
7
  model = BertForSequenceClassification.from_pretrained(
8
  "ashish-001/Bert-Amazon-review-sentiment-classifier")
9
 
 
 
 
 
10
 
11
  def classify_text(text):
12
  inputs = tokenizer(
@@ -21,10 +25,23 @@ def classify_text(text):
21
  probs = torch.nn.functional.sigmoid(logits)
22
  return probs
23
 
 
 
 
 
 
 
 
24
 
25
  st.title("Amazon Review Sentiment classifier")
26
  data = st.text_area("Enter or paste a review")
27
- if st.button('Predict'):
28
  prediction = classify_text(data)
29
  st.header(
30
  f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer, BertForSequenceClassification,DistilBertTokenizer,DistilBertForSequenceClassification
2
  import torch
3
  import streamlit as st
4
 
 
7
  model = BertForSequenceClassification.from_pretrained(
8
  "ashish-001/Bert-Amazon-review-sentiment-classifier")
9
 
10
+ distil_model = DistilBertForSequenceClassification.from_pretrained(
11
+ "ashish-001/DistilBert-Amazon-review-sentiment-classifier")
12
+ distil_tokenizer = DistilBertTokenizer.from_pretrained(
13
+ "ashish-001/DistilBert-Amazon-review-sentiment-classifier")
14
 
15
  def classify_text(text):
16
  inputs = tokenizer(
 
25
  probs = torch.nn.functional.sigmoid(logits)
26
  return probs
27
 
28
+ def classify_text_distilbert(text):
29
+ inputs=distil_tokenizer(text, return_tensors="pt")
30
+ output = distil_model(**inputs)
31
+ logits = output.logits
32
+ probs = torch.nn.functional.sigmoid(logits)
33
+ return probs
34
+
35
 
36
  st.title("Amazon Review Sentiment classifier")
37
  data = st.text_area("Enter or paste a review")
38
+ if st.button('Predict using BERT'):
39
  prediction = classify_text(data)
40
  st.header(
41
  f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")
42
+
43
+ if st.button('Predict Using DistilBERT'):
44
+ prediction = classify_text_distilbert(data)
45
+ st.header(
46
+ f"Negative Confidence: {prediction[0][0].item()}, Positive Confidence: {prediction[0][1].item()}")
47
+