kambris committed on
Commit
0156b72
·
verified ·
1 Parent(s): 79bbe0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -82
app.py CHANGED
@@ -21,40 +21,52 @@ EMOTION_LABELS = {
21
  'LABEL_2': 'Neutral'
22
  }
23
 
24
- def chunk_text(text, max_length=512):
25
- """Split text into chunks of maximum token length."""
 
 
 
 
26
  tokens = bert_tokenizer.encode(text, add_special_tokens=False)
27
  chunks = []
 
28
 
29
- for i in range(0, len(tokens), max_length - 2): # -2 to account for [CLS] and [SEP] tokens
30
- chunk = tokens[i:i + max_length - 2]
 
31
  # Add special tokens
32
- chunk = [bert_tokenizer.cls_token_id] + chunk + [bert_tokenizer.sep_token_id]
33
- chunks.append(chunk)
 
 
34
 
35
- return chunks
36
 
37
  def get_embedding_for_text(text):
38
- """Get embedding for a single text."""
39
- chunks = chunk_text(text)
40
  chunk_embeddings = []
41
 
42
- for chunk in chunks:
43
- # Convert to tensor and add batch dimension
44
- input_ids = torch.tensor([chunk]).to(bert_model.device)
45
- attention_mask = torch.ones_like(input_ids)
 
 
 
 
46
 
47
  with torch.no_grad():
48
- outputs = bert_model(input_ids, attention_mask=attention_mask)
49
 
50
- # Get [CLS] token embedding for this chunk
51
- chunk_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
52
- chunk_embeddings.append(chunk_embedding[0])
53
 
54
  # Average embeddings from all chunks
55
  if chunk_embeddings:
56
  return np.mean(chunk_embeddings, axis=0)
57
- return np.zeros(bert_model.config.hidden_size) # fallback
58
 
59
  def generate_embeddings(texts):
60
  """Generate embeddings for a list of texts."""
@@ -66,22 +78,28 @@ def generate_embeddings(texts):
66
  embeddings.append(embedding)
67
  except Exception as e:
68
  st.warning(f"Error processing text: {str(e)}")
69
- # Add zero embedding as fallback
70
  embeddings.append(np.zeros(bert_model.config.hidden_size))
71
 
72
  return np.array(embeddings)
73
 
74
  def classify_emotion(text):
75
- """Classify emotion for a single text."""
 
 
76
  try:
77
- chunks = chunk_text(text)
78
- if not chunks:
79
- return "unknown"
 
 
 
 
 
 
 
 
 
80
 
81
- # Use first chunk for classification
82
- chunk_text = bert_tokenizer.decode(chunks[0])
83
- result = emotion_classifier(chunk_text)[0]
84
- return result['label']
85
  except Exception as e:
86
  st.warning(f"Error in emotion classification: {str(e)}")
87
  return "unknown"
@@ -93,9 +111,7 @@ def format_topics(topic_model, topic_counts):
93
  if topic_num == -1:
94
  topic_label = "Miscellaneous"
95
  else:
96
- # Get the top words for this topic
97
  words = topic_model.get_topic(topic_num)
98
- # Take the top 3 words to form a topic label
99
  topic_label = " | ".join([word for word, _ in words[:3]])
100
 
101
  formatted_topics.append({
@@ -136,10 +152,11 @@ def process_and_summarize(uploaded_file, top_n=50):
136
  df['country'] = df['country'].str.strip()
137
  df = df.dropna(subset=['country', 'poem'])
138
 
139
- # Initialize BERTopic with specific parameters
140
  topic_model = BERTopic(
141
  language="arabic",
142
  calculate_probabilities=True,
 
143
  verbose=True
144
  )
145
 
@@ -151,26 +168,23 @@ def process_and_summarize(uploaded_file, top_n=50):
151
  texts = group['poem'].dropna().tolist()
152
  batch_size = 10
153
  all_emotions = []
154
- all_embeddings = []
155
 
 
 
 
 
 
156
  for i in range(0, len(texts), batch_size):
157
  batch_texts = texts[i:i + batch_size]
158
-
159
- st.info(f"Generating embeddings for batch {i//batch_size + 1}...")
160
- batch_embeddings = generate_embeddings(batch_texts)
161
- all_embeddings.extend(batch_embeddings)
162
-
163
  st.info(f"Classifying emotions for batch {i//batch_size + 1}...")
164
  batch_emotions = [classify_emotion(text) for text in batch_texts]
165
  all_emotions.extend(batch_emotions)
166
 
167
  try:
168
- embeddings = np.array(all_embeddings)
169
-
170
  st.info(f"Fitting topic model for {country}...")
171
  topics, _ = topic_model.fit_transform(texts, embeddings)
172
 
173
- # Format topics and emotions with readable labels
174
  top_topics = format_topics(topic_model, Counter(topics).most_common(top_n))
175
  top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
176
 
@@ -186,46 +200,4 @@ def process_and_summarize(uploaded_file, top_n=50):
186
 
187
  return summaries, topic_model
188
 
189
- # Streamlit App Interface
190
- st.title("Arabic Poem Topic Modeling & Emotion Classification")
191
- st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")
192
-
193
- uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
194
-
195
- if uploaded_file is not None:
196
- try:
197
- top_n = st.number_input("Select the number of top topics/emotions to display:",
198
- min_value=1, max_value=100, value=10)
199
-
200
- summaries, topic_model = process_and_summarize(uploaded_file, top_n=top_n)
201
- if summaries is not None:
202
- st.success("Data successfully processed!")
203
-
204
- # Display summary for each country
205
- for summary in summaries:
206
- st.write(f"### {summary['country']}")
207
- st.write(f"Total Poems: {summary['total_poems']}")
208
-
209
- st.write(f"\nTop {top_n} Topics:")
210
- for topic in summary['top_topics']:
211
- st.write(f"• {topic['topic']}: {topic['count']} poems")
212
-
213
- st.write(f"\nTop {top_n} Emotions:")
214
- for emotion in summary['top_emotions']:
215
- st.write(f"• {emotion['emotion']}: {emotion['count']} poems")
216
-
217
- st.write("---")
218
-
219
- # Display overall topics in a more readable format
220
- st.write("### Global Topic Information:")
221
- topic_info = topic_model.get_topic_info()
222
- for _, row in topic_info.iterrows():
223
- if row['Topic'] == -1:
224
- topic_name = "Miscellaneous"
225
- else:
226
- words = topic_model.get_topic(row['Topic'])
227
- topic_name = " | ".join([word for word, _ in words[:3]])
228
- st.write(f"• Topic {row['Topic']}: {topic_name} ({row['Count']} poems)")
229
-
230
- except Exception as e:
231
- st.error(f"Error: {str(e)}")
 
21
  'LABEL_2': 'Neutral'
22
  }
23
 
24
def chunk_long_text(text, max_length=512):
    """
    Split text into chunks respecting AraBERT's token limit.

    Args:
        text: The text to split. Empty (or None) input yields no chunks.
        max_length: Maximum number of tokens per chunk, counting the [CLS]
            and [SEP] tokens wrapped around each chunk. Must be at least 3
            so every chunk can carry at least one real token.

    Returns:
        A ``(chunks, text_chunks)`` tuple of parallel lists. ``chunks`` holds
        token-id lists with [CLS]/[SEP] added; ``text_chunks`` holds the same
        pieces decoded back to text without the special tokens.

    Raises:
        ValueError: If ``max_length`` is smaller than 3.
    """
    # Two slots of every chunk are reserved for [CLS] and [SEP].
    payload_size = max_length - 2
    if payload_size < 1:
        # A zero step makes range() fail with an unrelated-looking message and
        # a negative step silently produces no chunks, so reject both up front.
        raise ValueError(
            f"max_length must be at least 3 to leave room for special tokens, got {max_length}"
        )

    # Nothing to tokenize: skip the tokenizer entirely.
    if not text:
        return [], []

    # Tokenize the entire text once, without special tokens.
    tokens = bert_tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    text_chunks = []

    for start in range(0, len(tokens), payload_size):
        chunk = tokens[start:start + payload_size]
        # Add special tokens
        chunks.append(
            [bert_tokenizer.cls_token_id] + chunk + [bert_tokenizer.sep_token_id]
        )
        # Decode the chunk back to text (without special tokens)
        text_chunks.append(bert_tokenizer.decode(chunk))

    return chunks, text_chunks
44
 
45
  def get_embedding_for_text(text):
46
+ """Get embedding for a text, handling long sequences by averaging chunk embeddings."""
47
+ _, text_chunks = chunk_long_text(text)
48
  chunk_embeddings = []
49
 
50
+ for chunk in text_chunks:
51
+ # Encode chunk with padding and attention mask
52
+ inputs = bert_tokenizer(chunk,
53
+ return_tensors="pt",
54
+ padding=True,
55
+ truncation=True,
56
+ max_length=512)
57
+ inputs = {k: v.to(bert_model.device) for k, v in inputs.items()}
58
 
59
  with torch.no_grad():
60
+ outputs = bert_model(**inputs)
61
 
62
+ # Get [CLS] token embedding
63
+ embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
64
+ chunk_embeddings.append(embedding[0])
65
 
66
  # Average embeddings from all chunks
67
  if chunk_embeddings:
68
  return np.mean(chunk_embeddings, axis=0)
69
+ return np.zeros(bert_model.config.hidden_size)
70
 
71
  def generate_embeddings(texts):
72
  """Generate embeddings for a list of texts."""
 
78
  embeddings.append(embedding)
79
  except Exception as e:
80
  st.warning(f"Error processing text: {str(e)}")
 
81
  embeddings.append(np.zeros(bert_model.config.hidden_size))
82
 
83
  return np.array(embeddings)
84
 
85
  def classify_emotion(text):
86
+ """
87
+ Classify emotion for a text, handling long sequences by voting among chunks.
88
+ """
89
  try:
90
+ _, text_chunks = chunk_long_text(text)
91
+ chunk_emotions = []
92
+
93
+ for chunk in text_chunks:
94
+ result = emotion_classifier(chunk, max_length=512, truncation=True)[0]
95
+ chunk_emotions.append(result['label'])
96
+
97
+ # Use majority voting for final emotion
98
+ if chunk_emotions:
99
+ final_emotion = Counter(chunk_emotions).most_common(1)[0][0]
100
+ return final_emotion
101
+ return "unknown"
102
 
 
 
 
 
103
  except Exception as e:
104
  st.warning(f"Error in emotion classification: {str(e)}")
105
  return "unknown"
 
111
  if topic_num == -1:
112
  topic_label = "Miscellaneous"
113
  else:
 
114
  words = topic_model.get_topic(topic_num)
 
115
  topic_label = " | ".join([word for word, _ in words[:3]])
116
 
117
  formatted_topics.append({
 
152
  df['country'] = df['country'].str.strip()
153
  df = df.dropna(subset=['country', 'poem'])
154
 
155
+ # Initialize BERTopic with specific parameters for Arabic
156
  topic_model = BERTopic(
157
  language="arabic",
158
  calculate_probabilities=True,
159
+ min_topic_size=5,
160
  verbose=True
161
  )
162
 
 
168
  texts = group['poem'].dropna().tolist()
169
  batch_size = 10
170
  all_emotions = []
 
171
 
172
+ # Generate embeddings for all texts
173
+ st.info("Generating embeddings...")
174
+ embeddings = generate_embeddings(texts)
175
+
176
+ # Process emotions in batches
177
  for i in range(0, len(texts), batch_size):
178
  batch_texts = texts[i:i + batch_size]
 
 
 
 
 
179
  st.info(f"Classifying emotions for batch {i//batch_size + 1}...")
180
  batch_emotions = [classify_emotion(text) for text in batch_texts]
181
  all_emotions.extend(batch_emotions)
182
 
183
  try:
 
 
184
  st.info(f"Fitting topic model for {country}...")
185
  topics, _ = topic_model.fit_transform(texts, embeddings)
186
 
187
+ # Format results
188
  top_topics = format_topics(topic_model, Counter(topics).most_common(top_n))
189
  top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
190
 
 
200
 
201
  return summaries, topic_model
202
 
203
+ # Streamlit interface remains the same...