Update app.py
app.py
CHANGED
@@ -1,4 +1,4 @@
-#
+# -*- coding: utf-8 -*-
 """emotion-matcher.ipynb
 
 Automatically generated by Colab.
@@ -22,52 +22,33 @@ splits = {
 df = pd.read_parquet("hf://datasets/google-research-datasets/go_emotions/" + splits["train"])
 
 # Preview the first few rows of the dataset
-df.head()
+print(df.head())
 
-
-We use the pandas library and the read_parquet() method to read the data into a table (DataFrame).
-Then, we display the first few rows using df.head() to make sure the data was loaded correctly.
-This is a perfect starting point for the next step – Exploratory Data Analysis (EDA).
-"""
-
-#Import necessary libraries
-import pandas as pd
-
-#View dataset shape
+# View dataset shape
 print("Dataset shape:", df.shape)
 
-#View basic column information
+# View basic column information
 print("\nColumn names:", df.columns.tolist())
 
-#View detailed info
+# View detailed info
 df.info()
 
-
-This gives us an overview of what kind of data we’re dealing with (text, numbers, labels, etc.).
-It helps us understand what preprocessing may be needed next.
-"""
-
-#Check for missing values
-
+# Check for missing values
 print("Missing values per column:")
 print(df.isnull().sum())
 
-#Check for duplicated rows (convert unhashable columns to string)
-
+# Check for duplicated rows (convert unhashable columns to string)
 print("\nNumber of duplicated rows:")
 print(df.astype(str).duplicated().sum())
 
-#Check how many unique combinations of emotion labels exist
-
+# Check how many unique combinations of emotion labels exist
 print("\nNumber of unique label combinations:")
 print(df["labels"].apply(lambda x: tuple(x)).nunique())
 
-#Compute text lengths in number of words
-
+# Compute text lengths in number of words
 df["text_length"] = df["text"].apply(lambda x: len(x.split()))
 
-#Plot histogram of text lengths
-
+# Plot histogram of text lengths
 import matplotlib.pyplot as plt
 
 plt.figure(figsize=(10,6))
@@ -78,16 +59,10 @@ plt.ylabel("Number of samples")
 plt.grid(True)
 plt.show()
 
-
-
-"""
-
-#Count how many emotion labels each text has
-
+# Count how many emotion labels each text has
 df["num_labels"] = df["labels"].apply(len)
 
-#Plot distribution
-
+# Plot distribution
 plt.figure(figsize=(8,5))
 df["num_labels"].value_counts().sort_index().plot(kind="bar")
 plt.xlabel("Number of emotion labels")
@@ -95,8 +70,6 @@ plt.ylabel("Number of samples")
 plt.title("Distribution of Emotion Labels per Sample")
 plt.show()
 
-"""Most samples are annotated with a single emotion label, and very few have multiple labels. This indicates that the dataset is mostly suitable for single-label classification tasks, although a multi-label approach could still capture additional nuance for rare cases."""
-
 # Count frequency of each individual emotion label
 from collections import Counter
 
@@ -105,7 +78,6 @@ all_labels = [label for labels in df["labels"] for label in labels]
 label_counts = Counter(all_labels)
 
 # Convert to DataFrame for plotting
-import pandas as pd
 emotion_freq = pd.DataFrame.from_dict(label_counts, orient='index', columns=['count'])
 emotion_freq = emotion_freq.sort_values(by='count', ascending=False)
 
@@ -116,23 +88,12 @@ plt.xlabel("Emotion Label ID")
 plt.ylabel("Number of Occurrences")
 plt.show()
 
-
-We observe a strong imbalance: some labels like 27 (likely “neutral”) dominate with over 14,000 occurrences,
-while others like 16, 21, or 23 are very rare.
-This highlights the need to consider class imbalance when training models.
-"""
-
-# Import necessary libraries
+# Create a binary matrix for emotions
 import numpy as np
-import matplotlib.pyplot as plt
 import seaborn as sns
 
-# Create a binary matrix for emotions
-# Get the maximum label ID from all label lists
 num_labels = max([max(l.tolist()) if len(l) > 0 else 0 for l in df["labels"]]) + 1
-
 emotion_matrix = np.zeros((len(df), num_labels), dtype=int)
-
 for i, labels in enumerate(df["labels"]):
     for label in labels:
         emotion_matrix[i, label] = 1
@@ -145,25 +106,13 @@ plt.figure(figsize=(12, 10))
 sns.heatmap(co_occurrence, cmap="Blues", linewidths=0.5)
 plt.title("Emotion Co-occurrence Heatmap")
 plt.xlabel("Emotion Label ID")
-plt.ylabel("Emotion Label
+plt.ylabel("Emotion Label ID")
 plt.show()
 
-"""This heatmap visualizes how frequently pairs of emotion labels co-occur within the same text. Darker shades indicate more frequent co-occurrences, helping identify emotions that often appear together."""
-
-# View random samples of texts and their corresponding emotion labels
-
 # Display 5 random rows
 print("Sample text examples with emotion labels:")
 print(df.sample(5)[["text", "labels"]])
 
-"""This step is meant to get a qualitative sense of the dataset by inspecting real examples. It helps verify whether:
-
-The texts are understandable and relevant.
-The assigned emotion labels make sense.
-There are any noisy, overly short, or unclear samples.
-
-"""
-
 # Define emotion label ID to name mapping manually (based on GoEmotions documentation)
 id2label = [
     'admiration', 'amusement', 'anger', 'annoyance', 'approval',
@@ -174,7 +123,6 @@ id2label = [
     'neutral'
 ]
 
-# Define a function to convert list of label IDs into label names
 def decode_labels(label_ids):
     return [id2label[i] for i in label_ids]
 
@@ -182,77 +130,44 @@ def decode_labels(label_ids):
 print("Sample text examples with emotion label names:")
 sample_df = df.sample(5)
 sample_df["label_names"] = sample_df["labels"].apply(decode_labels)
-
-
-"""Sample Texts with Emotion Labels
+print(sample_df[["text", "label_names"]])
 
-
-"""
-
-# Import library for word cloud
+# Word cloud
 from wordcloud import WordCloud
-import matplotlib.pyplot as plt
 
-# Combine all text data into one string
 all_text = " ".join(df["text"])
-
-# Generate word cloud
 wordcloud = WordCloud(width=800, height=400, background_color='white').generate(all_text)
 
-# Plot word cloud
 plt.figure(figsize=(12, 6))
 plt.imshow(wordcloud, interpolation="bilinear")
 plt.axis("off")
 plt.title("Most Frequent Words in All Text Samples")
 plt.show()
 
-
-
-# Step: Text Preprocessing - clean the text data
+# Clean the text data
 import re
 import string
 
-# Define a function to clean each text entry
 def clean_text(text):
-    # Lowercase
     text = text.lower()
-    # Remove [NAME], [URL], and other placeholders
     text = re.sub(r"\[.*?\]", "", text)
-    # Remove punctuation
     text = text.translate(str.maketrans('', '', string.punctuation))
-    # Remove numbers
     text = re.sub(r"\d+", "", text)
-    # Remove extra whitespaces
     text = re.sub(r"\s+", " ", text).strip()
     return text
 
-# Apply cleaning to the text column
 df["clean_text"] = df["text"].apply(clean_text)
 
-# Preview cleaned text
 print("Sample cleaned texts:")
-
-
-"""
-This preprocessing step standardizes text inputs by converting to lowercase, removing brackets like [NAME], punctuation, digits, and extra spaces — which helps downstream models focus on meaningful content.
-"""
+print(df[["text", "clean_text"]].sample(5))
 
 # Plot label distribution
-
-# Flatten all label lists into a single list
-all_labels = [label for sublist in df["labels"] for label in sublist]
-
-# Count frequency of each label
-from collections import Counter
-label_counts = Counter(all_labels)
-
-# Convert to DataFrame for plotting
+label_counts = Counter([label for sublist in df["labels"] for label in sublist])
 label_df = pd.DataFrame.from_dict(label_counts, orient="index", columns=["count"])
 label_df.index.name = "label_id"
 label_df = label_df.sort_index()
 label_df["label_name"] = label_df.index.map(lambda i: id2label[i])
 
-# Plot bar chart
 plt.figure(figsize=(14, 6))
 sns.barplot(x="label_name", y="count", data=label_df)
 plt.xticks(rotation=45, ha="right")
@@ -262,73 +177,27 @@ plt.ylabel("Frequency")
 plt.tight_layout()
 plt.show()
 
-
-
-## 2. Embeddings
-"""
-
-# Import required libraries
+# Embeddings
 from sentence_transformers import SentenceTransformer
 import torch
 
-# Choose a small and fast model for generating sentence embeddings
 model = SentenceTransformer('all-MiniLM-L6-v2')
-
-# Optional: move model to GPU if available
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model = model.to(device)
 
-# Subset the dataset to 2000 samples for efficiency
-sample_df = df.sample(n=2000, random_state=42).reset_index(drop=True)
-
-# Generate embeddings for the 'clean_text' column
-# This might take 1-2 minutes
-embeddings = model.encode(
-    sample_df['clean_text'].tolist(),
-    convert_to_tensor=True,
-    show_progress_bar=True,
-    device=device
-)
-
-# Store embeddings as a list inside the dataframe
-sample_df['embedding'] = embeddings.cpu().numpy().tolist()
-
-# Preview the result
-sample_df[['clean_text', 'embedding']].head()
-
-"""
-
-We use the all-MiniLM-L6-v2 model from SentenceTransformers to convert each cleaned text into a dense vector representation, capturing semantic meaning for further clustering and visualization.
-"""
-
-from tqdm.notebook import tqdm
-
-
 sample_df = df.sample(n=3000, random_state=42).reset_index(drop=True)
-
-
-embeddings = model.encode(sample_df["clean_text"].tolist(), show_progress_bar=True)
-
+embeddings = model.encode(sample_df["clean_text"].tolist(), show_progress_bar=True, device=device)
 sample_df["embedding"] = embeddings.tolist()
 
-
-
+# t-SNE visualization
 from sklearn.manifold import TSNE
-import matplotlib.pyplot as plt
-import numpy as np
 
-# Convert list of embeddings to a NumPy array
 X = np.array(sample_df["embedding"].tolist())
-
-# Reduce the embedding dimensions to 2D using t-SNE
 tsne = TSNE(n_components=2, random_state=42, perplexity=30)
 X_embedded = tsne.fit_transform(X)
-
-# Add 2D coordinates to the dataframe
 sample_df["x"] = X_embedded[:, 0]
 sample_df["y"] = X_embedded[:, 1]
 
-# Visualize the 2D embeddings using a scatter plot
 plt.figure(figsize=(10, 6))
 plt.scatter(sample_df["x"], sample_df["y"], alpha=0.5)
 plt.title("t-SNE Projection of Text Embeddings")
@@ -336,18 +205,13 @@ plt.xlabel("Component 1")
 plt.ylabel("Component 2")
 plt.show()
 
-
-
+# KMeans Clustering
 from sklearn.cluster import KMeans
 
-# Define the number of clusters (you can try different values like 5, 10, etc.)
 num_clusters = 8
-
-# Apply K-Means clustering to the embeddings
 kmeans = KMeans(n_clusters=num_clusters, random_state=42)
 sample_df["cluster"] = kmeans.fit_predict(X)
 
-# Visualize the clusters on the t-SNE projection
 plt.figure(figsize=(10, 6))
 scatter = plt.scatter(sample_df["x"], sample_df["y"], c=sample_df["cluster"], cmap='tab10', alpha=0.6)
 plt.title(f"K-Means Clustering (k={num_clusters}) on t-SNE Projection")
@@ -356,46 +220,26 @@ plt.ylabel("Component 2")
 plt.colorbar(scatter, label="Cluster")
 plt.show()
 
-
-
-## 3. Inputs & Outputs
-"""
-
+# Recommendation Function
 from sentence_transformers import util
-import torch
 
-# Ensure sample_df contains the 'embedding' column
 EMBEDDINGS = torch.tensor(sample_df['embedding'].tolist(), device=device)
 
-# Define the recommendation function
 def recommend_similar_emotions(user_input):
     if not user_input.strip():
         return "Please enter some text."
-
-    # Encode the user input into an embedding
     user_embedding = model.encode(user_input, convert_to_tensor=True, device=device)
-
-    # Compute cosine similarity between user input and all stored embeddings
    similarities = util.cos_sim(user_embedding, EMBEDDINGS)[0]
    top_indices = similarities.argsort(descending=True)[:5]
-
-    # Format the top 5 most similar results
    results = []
    for idx in top_indices:
        row = sample_df.iloc[idx.item()]
        results.append(f"{row['text']}\nEmotions: {row['labels']}")
-
    return "\n\n".join(results)
 
-
-
-"""Core recommendation logic for matching user input text to most similar texts in the dataset using sentence embeddings and cosine similarity.
-Returns top 5 results with their associated emotion labels.
-"""
-
+# Gradio App
 import gradio as gr
 
-# Create Gradio interface
 demo = gr.Interface(
     fn=recommend_similar_emotions,
     inputs=gr.Textbox(lines=2, placeholder="Type your situation or feeling..."),
@@ -404,6 +248,4 @@ demo = gr.Interface(
     description="Describe how you feel, and get similar examples with emotion labels."
 )
 
-demo.launch()
-
-"""Set up the Gradio web app for entering text and viewing recommendations"""
+demo.launch()