Update app.py

app.py CHANGED
@@ -21,6 +21,71 @@ EMOTION_LABELS = {
     'LABEL_2': 'Neutral'
 }
 
+def chunk_text(text, max_length=512):
+    """Split text into chunks of maximum token length."""
+    tokens = bert_tokenizer.encode(text, add_special_tokens=False)
+    chunks = []
+
+    for i in range(0, len(tokens), max_length - 2):  # -2 to account for [CLS] and [SEP] tokens
+        chunk = tokens[i:i + max_length - 2]
+        # Add special tokens
+        chunk = [bert_tokenizer.cls_token_id] + chunk + [bert_tokenizer.sep_token_id]
+        chunks.append(chunk)
+
+    return chunks
+
+def get_embedding_for_text(text):
+    """Get embedding for a single text."""
+    chunks = chunk_text(text)
+    chunk_embeddings = []
+
+    for chunk in chunks:
+        # Convert to tensor and add batch dimension
+        input_ids = torch.tensor([chunk]).to(bert_model.device)
+        attention_mask = torch.ones_like(input_ids)
+
+        with torch.no_grad():
+            outputs = bert_model(input_ids, attention_mask=attention_mask)
+
+        # Get [CLS] token embedding for this chunk
+        chunk_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
+        chunk_embeddings.append(chunk_embedding[0])
+
+    # Average embeddings from all chunks
+    if chunk_embeddings:
+        return np.mean(chunk_embeddings, axis=0)
+    return np.zeros(bert_model.config.hidden_size)  # fallback
+
+def generate_embeddings(texts):
+    """Generate embeddings for a list of texts."""
+    embeddings = []
+
+    for text in texts:
+        try:
+            embedding = get_embedding_for_text(text)
+            embeddings.append(embedding)
+        except Exception as e:
+            st.warning(f"Error processing text: {str(e)}")
+            # Add zero embedding as fallback
+            embeddings.append(np.zeros(bert_model.config.hidden_size))
+
+    return np.array(embeddings)
+
+def classify_emotion(text):
+    """Classify emotion for a single text."""
+    try:
+        chunks = chunk_text(text)
+        if not chunks:
+            return "unknown"
+
+        # Use first chunk for classification (decode its ids back to text)
+        first_chunk_text = bert_tokenizer.decode(chunks[0], skip_special_tokens=True)
+        result = emotion_classifier(first_chunk_text)[0]
+        return result['label']
+    except Exception as e:
+        st.warning(f"Error in emotion classification: {str(e)}")
+        return "unknown"
+
 def format_topics(topic_model, topic_counts):
     """Convert topic numbers to readable labels."""
     formatted_topics = []
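Taken together, the new helpers turn arbitrarily long poems into one fixed-size vector each: chunk_text windows the token ids so no chunk exceeds BERT's 512-token limit, get_embedding_for_text mean-pools the per-chunk [CLS] embeddings, and generate_embeddings maps that over a list with zero-vector fallbacks. One subtlety in classify_emotion: the decoded chunk must not be bound to a local named chunk_text, since that would shadow the helper function and make the earlier chunk_text(text) call raise UnboundLocalError. The hunk does not show where these helpers are called, but vectors like these would typically feed BERTopic through the embeddings argument of fit_transform. A minimal sketch under that assumption (df, topic_model, and the column name are taken from elsewhere in app.py):

```python
# Hypothetical call site -- not part of this commit.
texts = df['poem'].tolist()
doc_embeddings = generate_embeddings(texts)  # shape: (n_docs, hidden_size)

# BERTopic accepts precomputed embeddings, skipping its own encoder.
topics, probs = topic_model.fit_transform(texts, embeddings=doc_embeddings)

# Per-document emotion labels, each from the first <=512-token chunk.
df['emotion'] = [classify_emotion(t) for t in texts]
```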
@@ -50,10 +115,26 @@ def format_emotions(emotion_counts):
     })
     return formatted_emotions
 
-# [Previous functions remain the same until process_and_summarize]
-
 def process_and_summarize(uploaded_file, top_n=50):
-    #
+    # Determine the file type
+    if uploaded_file.name.endswith(".csv"):
+        df = pd.read_csv(uploaded_file)
+    elif uploaded_file.name.endswith(".xlsx"):
+        df = pd.read_excel(uploaded_file)
+    else:
+        st.error("Unsupported file format.")
+        return None, None
+
+    # Validate required columns
+    required_columns = ['country', 'poem']
+    missing_columns = [col for col in required_columns if col not in df.columns]
+    if missing_columns:
+        st.error(f"Missing columns: {', '.join(missing_columns)}")
+        return None, None
+
+    # Parse and preprocess the file
+    df['country'] = df['country'].str.strip()
+    df = df.dropna(subset=['country', 'poem'])
 
     # Initialize BERTopic with specific parameters
     topic_model = BERTopic(
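process_and_summarize now signals every failure path (unsupported extension, missing country/poem columns) with st.error and a None, None return, so the Streamlit caller has to treat None as a sentinel rather than assuming a DataFrame comes back. A sketch of the expected call pattern, assuming a standard st.file_uploader widget and return names that are not shown in this diff:

```python
# Hypothetical caller -- widget label and return names are assumptions.
uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
if uploaded_file is not None:
    summaries, topic_model = process_and_summarize(uploaded_file, top_n=50)
    if summaries is not None:  # None, None means the file failed validation
        st.dataframe(summaries)
```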