Ashed00 committed on
Commit
a9dac34
·
verified ·
1 Parent(s): 649d707

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import shap
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
# Load model and tokenizer with caching
@st.cache_resource
def load_model():
    """Load the sentiment tokenizer and classifier from the same checkpoint.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    are reused instead of being re-created each time it is called.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Single source of truth for the checkpoint shared by both loaders.
    checkpoint = "nlptown/bert-base-multilingual-uncased-sentiment"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return tok, classifier


# Module-level handles used by the prediction and explanation code.
tokenizer, model = load_model()
16
+
17
# Define prediction function
def predict(texts):
    """Return class probabilities for a batch of texts.

    Args:
        texts: Iterable of inputs. Each item is either a string or a list of
            tokens; token lists are joined back into a string with the
            tokenizer before encoding.

    Returns:
        A NumPy array holding the softmax of the model logits over the last
        axis, one row per input.
    """
    # Normalise every item to a plain string the tokenizer can batch-encode.
    batch = [
        tokenizer.convert_tokens_to_string(item) if isinstance(item, list) else item
        for item in texts
    ]

    encoded = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    probs = torch.softmax(logits, dim=-1)
    return probs.numpy()
37
+
38
# Initialize SHAP components
# Derive the class labels from the model config instead of hard-coding the
# class count, so the list always matches the loaded checkpoint.
output_names = [
    model.config.id2label[i] for i in range(len(model.config.id2label))
]
# Text masker: hidden tokens are replaced by the tokenizer's mask token, and
# runs of adjacent mask tokens are collapsed into one.
masker = shap.maskers.Text(
    tokenizer=tokenizer,
    mask_token=tokenizer.mask_token,
    collapse_mask_token=True,
)
# Explainer that attributes each class probability from `predict` to tokens.
explainer = shap.Explainer(predict, masker, output_names=output_names)
42
+
43
# Streamlit UI
# Imported here because only this section needs it: SHAP's text plot is HTML,
# which has to be embedded through the components API.
import streamlit.components.v1 as components

# Session-state key bound to the main input box, so example buttons can fill it.
INPUT_KEY = "text_input"


def _use_example(example):
    """Copy an example sentence into the main input box.

    Runs as a button ``on_click`` callback, i.e. before the widgets are
    re-created on the next rerun, which is when Streamlit allows writing to a
    widget-bound session-state key.
    """
    st.session_state[INPUT_KEY] = example


st.title("🎯 BERT Sentiment Analysis with SHAP")
st.markdown("""
**How it works:**
1. Enter text in the box below
2. See predicted sentiment (1-5 stars)
3. View confidence scores and word-level explanations
""")

text_input = st.text_area(
    "Input Text",
    placeholder="Enter text to analyze...",
    height=100,
    key=INPUT_KEY,
)

if st.button("Analyze Sentiment"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            # Get predictions
            probabilities = predict([text_input])[0]
            predicted_class = int(np.argmax(probabilities))

            # Display results
            st.subheader("📊 Results")
            cols = st.columns(2)
            cols[0].metric("Predicted Sentiment", output_names[predicted_class])

            with cols[1]:
                st.markdown("**Confidence Scores**")
                for label, score in zip(output_names, probabilities):
                    # st.progress accepts only built-in int/float; the model
                    # yields numpy.float32, so convert explicitly.
                    score = float(score)
                    st.progress(score, text=f"{label}: {score:.1%}")

            # Generate SHAP explanations
            st.subheader("🔍 Explanation")
            st.markdown("""
**Word impacts**
Red → Increases score | Blue → Decreases score
Intensity shows magnitude of impact
""")

            shap_values = explainer([text_input])

            # Create tabs for each sentiment class
            tabs = st.tabs(output_names)
            for i, tab in enumerate(tabs):
                with tab:
                    # With display=False, shap.plots.text returns an HTML
                    # string (not a Matplotlib figure), so embed it as HTML.
                    explanation_html = shap.plots.text(
                        shap_values[:, :, i], display=False
                    )
                    components.html(explanation_html, height=300, scrolling=True)
    else:
        st.warning("Please enter some text to analyze")

st.markdown("---")
st.markdown("Example texts to try:")
example_texts = [
    "This product exceeded all my expectations!",
    "Terrible customer service experience.",
    "The movie was okay, nothing special.",
    "You are kinda cool",
]
examples = st.columns(len(example_texts))

# Each button loads its sentence into the main input box (via the callback)
# so it can be analysed with the "Analyze Sentiment" button.
for col, text in zip(examples, example_texts):
    with col:
        st.button(
            text,
            use_container_width=True,
            on_click=_use_example,
            args=(text,),
        )