DexterSptizu commited on
Commit
67dc5f6
Β·
verified Β·
1 Parent(s): cfcdcd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -137
app.py CHANGED
@@ -1,180 +1,184 @@
1
  import streamlit as st
2
  import numpy as np
3
- from sentence_transformers import SentenceTransformer
4
- import plotly.express as px
5
  import plotly.graph_objects as go
6
- from sklearn.manifold import TSNE
7
- import torch
8
- from transformers import AutoTokenizer, AutoModel
9
  import pandas as pd
10
- from sentence_transformers import SentenceTransformer, util # Added util import
11
-
12
 
13
  # Page configuration
14
- st.set_page_config(layout="wide", page_title="Word & Sentence Embeddings Explorer")
15
 
 
16
  @st.cache_resource
17
- def load_models():
18
- sent_model = SentenceTransformer('all-MiniLM-L6-v2')
19
- word_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
20
- word_model = AutoModel.from_pretrained('bert-base-uncased')
21
- return sent_model, word_tokenizer, word_model
22
 
23
- sent_model, word_tokenizer, word_model = load_models()
 
 
 
24
 
25
- def get_word_embeddings(text):
26
- # Tokenize and get word embeddings
27
- tokens = word_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
28
- with torch.no_grad():
29
- outputs = word_model(**tokens)
30
- word_embeddings = outputs.last_hidden_state.squeeze(0)
31
 
32
- # Get original words from tokens
33
- words = word_tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])
34
 
35
- return words, word_embeddings
36
-
37
- def create_heatmap(embeddings, words):
38
- # Create heatmap of word embeddings
39
- fig = go.Figure(data=go.Heatmap(
40
- z=embeddings,
41
- x=[f'Dim {i+1}' for i in range(embeddings.shape[1])],
42
- y=words,
43
- colorscale='Viridis'
 
44
  ))
45
 
46
  fig.update_layout(
47
- title='Word Embeddings Heatmap',
48
- xaxis_title='Embedding Dimensions',
49
- yaxis_title='Words',
50
  height=400
51
  )
52
  return fig
53
 
54
- def create_word_scatter(embeddings, words):
55
- # Calculate appropriate perplexity value
56
- n_samples = len(embeddings)
57
- # Perplexity should be between 5 and 50, and less than n_samples
58
- perplexity = min(30, n_samples - 1) # Default is 30, but ensure it's less than n_samples
59
-
60
- # Reduce dimensions for visualization using t-SNE
61
- tsne = TSNE(
62
- n_components=2,
63
- perplexity=perplexity,
64
- random_state=42,
65
- init='random',
66
- learning_rate='auto'
67
- )
68
-
69
- # Perform t-SNE dimensionality reduction
70
- embeddings_2d = tsne.fit_transform(embeddings)
71
-
72
- # Create scatter plot
73
- fig = px.scatter(
74
- x=embeddings_2d[:, 0],
75
- y=embeddings_2d[:, 1],
76
- text=words,
77
- title=f'Word Embeddings in 2D Space (perplexity={perplexity})'
78
- )
79
 
80
- # Update layout for better visualization
81
- fig.update_traces(
82
- textposition='top center',
83
- mode='markers+text'
84
- )
85
  fig.update_layout(
86
- height=400,
87
- showlegend=False,
88
- xaxis_title="t-SNE dimension 1",
89
- yaxis_title="t-SNE dimension 2"
90
  )
91
-
92
  return fig
93
 
94
  def main():
95
- st.title("πŸ”€ Interactive Word & Sentence Embeddings Explorer")
96
 
97
- with st.expander("ℹ️ About this app", expanded=True):
98
  st.markdown("""
99
- This app helps you understand how words and sentences are represented in vector space:
100
- - **Word-level Analysis**: See how individual words are embedded
101
- - **Sentence-level Analysis**: Compare different sentences
102
- - **Interactive Visualizations**: Explore embeddings through various charts
 
 
103
  """)
104
 
105
- col1, col2 = st.columns([2, 1])
 
 
 
106
 
107
  with col1:
108
- text_input = st.text_area(
109
- "Enter your text",
110
- value="The quick brown fox jumps over the lazy dog",
111
- height=100,
112
- help="Enter any text to see its word and sentence embeddings"
113
- )
 
 
 
 
 
 
 
 
 
 
 
 
114
 
 
 
 
115
  with col2:
116
- st.markdown("### Visualization Options")
117
- show_heatmap = st.checkbox("Show Heatmap", value=True)
118
- show_scatter = st.checkbox("Show Word Scatter", value=True)
119
- show_sentence = st.checkbox("Show Sentence Analysis", value=True)
 
 
120
 
121
- if text_input:
122
- # Get word-level embeddings
123
- words, word_embeddings = get_word_embeddings(text_input)
124
- word_embeddings = word_embeddings.numpy()
125
-
126
- # Remove special tokens
127
- mask = ~np.isin(words, ['[CLS]', '[SEP]', '[PAD]'])
128
- words = [w for i, w in enumerate(words) if mask[i]]
129
- word_embeddings = word_embeddings[mask]
130
-
131
- # Create visualizations
132
- if show_heatmap:
133
- st.plotly_chart(create_heatmap(word_embeddings, words), use_container_width=True)
134
-
135
- if show_scatter:
136
- st.plotly_chart(create_word_scatter(word_embeddings, words), use_container_width=True)
137
-
138
- if show_sentence:
139
- st.markdown("### Sentence-Level Analysis")
140
-
141
- # Get sentence embedding
142
- sentence_embedding = sent_model.encode(text_input)
143
 
144
- # Create sentence embedding visualization
145
- fig = go.Figure(data=go.Bar(
146
- x=list(range(len(sentence_embedding))),
147
- y=sentence_embedding,
148
- name='Sentence Embedding'
149
- ))
150
 
151
- fig.update_layout(
152
- title='Sentence Embedding Vector',
153
- xaxis_title='Dimension',
154
- yaxis_title='Value',
155
- height=300
156
- )
157
 
158
- st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
159
 
160
- # Add similarity comparison
161
- st.markdown("### Compare with Another Sentence")
162
- compare_text = st.text_area("Enter another sentence for comparison",
163
- value="A quick brown dog jumps over the lazy fox",
164
- height=100)
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- if compare_text:
167
- # Calculate similarity using the imported util
168
- similarity = util.pytorch_cos_sim(
169
- sent_model.encode(text_input, convert_to_tensor=True),
170
- sent_model.encode(compare_text, convert_to_tensor=True)
171
- ).item()
 
 
 
 
 
 
172
 
173
- st.metric(
174
- label="Semantic Similarity",
175
- value=f"{similarity:.2f}",
176
- help="1.0 = identical meaning, 0.0 = completely different"
177
- )
 
 
 
178
 
179
  if __name__ == "__main__":
180
  main()
 
1
  import streamlit as st
2
  import numpy as np
3
+ from sentence_transformers import SentenceTransformer, util
 
4
  import plotly.graph_objects as go
5
+ import plotly.express as px
6
+ from typing import List, Tuple
 
7
  import pandas as pd
 
 
8
 
9
  # Page configuration
10
+ st.set_page_config(layout="wide", page_title="🎯 Sentence Transformer Explorer")
11
 
12
+ # Load model
13
  @st.cache_resource
14
+ def load_model():
15
+ return SentenceTransformer('all-MiniLM-L6-v2')
16
+
17
+ model = load_model()
 
18
 
19
+ def get_embedding_and_similarity(sentences: List[str]) -> Tuple[np.ndarray, np.ndarray]:
20
+ embeddings = model.encode(sentences)
21
+ similarity_matrix = util.cos_sim(embeddings, embeddings).numpy()
22
+ return embeddings, similarity_matrix
23
 
24
+ def create_word_importance_visualization(sentence: str, embedding: np.ndarray):
25
+ # Calculate word-level contribution to the embedding
26
+ words = sentence.split()
27
+ word_embeddings = model.encode(words)
 
 
28
 
29
+ # Calculate each word's average contribution
30
+ word_importance = np.mean(np.abs(word_embeddings), axis=1)
31
 
32
+ # Create word importance visualization
33
+ fig = go.Figure()
34
+
35
+ # Add word bars
36
+ fig.add_trace(go.Bar(
37
+ x=words,
38
+ y=word_importance,
39
+ marker_color='rgb(158,202,225)',
40
+ text=np.round(word_importance, 3),
41
+ textposition='auto',
42
  ))
43
 
44
  fig.update_layout(
45
+ title="Word Importance in Embedding",
46
+ xaxis_title="Words",
47
+ yaxis_title="Average Contribution",
48
  height=400
49
  )
50
  return fig
51
 
52
+ def create_similarity_heatmap(sentences: List[str], similarity_matrix: np.ndarray):
53
+ fig = go.Figure(data=go.Heatmap(
54
+ z=similarity_matrix,
55
+ x=sentences,
56
+ y=sentences,
57
+ colorscale='RdBu',
58
+ text=np.round(similarity_matrix, 3),
59
+ texttemplate='%{text}',
60
+ textfont={"size": 10},
61
+ hoverongaps=False
62
+ ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 
 
 
 
 
64
  fig.update_layout(
65
+ title="Sentence Similarity Matrix",
66
+ height=400
 
 
67
  )
 
68
  return fig
69
 
70
  def main():
71
+ st.title("🎯 Interactive Sentence Transformer Explorer")
72
 
73
+ with st.expander("ℹ️ How it works", expanded=True):
74
  st.markdown("""
75
+ This interactive tool helps you understand how Sentence Transformers work:
76
+
77
+ 1. **Sentence Embedding**: Convert sentences into numerical vectors
78
+ 2. **Word Importance**: See how each word contributes to the final embedding
79
+ 3. **Similarity Analysis**: Compare how similar sentences are to each other
80
+ 4. **Interactive Examples**: Try different sentences and see the results
81
  """)
82
 
83
+ # Interactive sentence input
84
+ st.subheader("πŸ”€ Enter Your Sentences")
85
+
86
+ col1, col2 = st.columns(2)
87
 
88
  with col1:
89
+ # Example templates
90
+ example_templates = {
91
+ "Similar Meanings": [
92
+ "I love programming in Python",
93
+ "Coding with Python is my favorite",
94
+ "I enjoy developing software using Python"
95
+ ],
96
+ "Different Topics": [
97
+ "The cat sleeps on the mat",
98
+ "Python is a programming language",
99
+ "The weather is beautiful today"
100
+ ],
101
+ "Semantic Relations": [
102
+ "Paris is the capital of France",
103
+ "Berlin is the capital of Germany",
104
+ "London is the capital of England"
105
+ ]
106
+ }
107
 
108
+ selected_template = st.selectbox("Choose an example template:",
109
+ list(example_templates.keys()))
110
+
111
  with col2:
112
+ if st.button("Load Example"):
113
+ sentences = example_templates[selected_template]
114
+ else:
115
+ sentences = ["I love programming in Python",
116
+ "Coding with Python is my favorite",
117
+ "The weather is beautiful today"]
118
 
119
+ # Dynamic sentence input
120
+ num_sentences = st.slider("Number of sentences:", 2, 5, 3)
121
+ sentences = []
122
+
123
+ for i in range(num_sentences):
124
+ sentence = st.text_input(f"Sentence {i+1}",
125
+ value=sentences[i] if i < len(sentences) else "")
126
+ sentences.append(sentence)
127
+
128
+ if st.button("Analyze Sentences", type="primary"):
129
+ if all(sentences):
130
+ embeddings, similarity_matrix = get_embedding_and_similarity(sentences)
 
 
 
 
 
 
 
 
 
 
131
 
132
+ st.subheader("πŸ“Š Analysis Results")
 
 
 
 
 
133
 
134
+ # Create tabs for different visualizations
135
+ tab1, tab2, tab3 = st.tabs(["Word Importance", "Sentence Similarity", "Embedding Space"])
 
 
 
 
136
 
137
+ with tab1:
138
+ st.markdown("### πŸ” Word-Level Analysis")
139
+ for i, sentence in enumerate(sentences):
140
+ st.markdown(f"**Sentence {i+1}:** {sentence}")
141
+ fig = create_word_importance_visualization(sentence, embeddings[i])
142
+ st.plotly_chart(fig, use_container_width=True)
143
 
144
+ with tab2:
145
+ st.markdown("### 🀝 Sentence Similarity Analysis")
146
+ fig = create_similarity_heatmap(sentences, similarity_matrix)
147
+ st.plotly_chart(fig, use_container_width=True)
148
+
149
+ # Add similarity interpretation
150
+ st.markdown("#### πŸ’‘ Interpretation")
151
+ for i in range(len(sentences)):
152
+ for j in range(i+1, len(sentences)):
153
+ similarity = similarity_matrix[i][j]
154
+ interpretation = (
155
+ "Very similar" if similarity > 0.8
156
+ else "Moderately similar" if similarity > 0.5
157
+ else "Different"
158
+ )
159
+ st.write(f"Sentences {i+1} & {i+2}: {interpretation} ({similarity:.3f})")
160
 
161
+ with tab3:
162
+ st.markdown("### 🎯 Interactive Embedding Analysis")
163
+
164
+ # Create embedding statistics
165
+ embedding_stats = pd.DataFrame({
166
+ 'Sentence': sentences,
167
+ 'Embedding_Length': [np.linalg.norm(emb) for emb in embeddings],
168
+ 'Mean_Value': [np.mean(emb) for emb in embeddings],
169
+ 'Std_Dev': [np.std(emb) for emb in embeddings]
170
+ })
171
+
172
+ st.dataframe(embedding_stats)
173
 
174
+ st.markdown("""
175
+ #### πŸ“ Understanding Embeddings
176
+ - **Embedding Length**: Represents the magnitude of the vector
177
+ - **Mean Value**: Average of all dimensions
178
+ - **Standard Deviation**: Spread of values across dimensions
179
+ """)
180
+ else:
181
+ st.warning("Please enter all sentences before analyzing.")
182
 
183
  if __name__ == "__main__":
184
  main()