kambris commited on
Commit
4c9a0ea
·
verified ·
1 Parent(s): 89e32b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -37
app.py CHANGED
@@ -195,12 +195,13 @@ def get_embedding_for_text(text, tokenizer, model):
195
  continue
196
 
197
  if chunk_embeddings:
198
- weights = np.array([len(chunk.split()) for chunk in chunks])
199
- weights = weights / weights.sum()
200
- weighted_embedding = np.average(chunk_embeddings, axis=0, weights=weights)
201
- return weighted_embedding
202
- return np.zeros(model.config.hidden_size)
203
-
 
204
  def format_topics(topic_model, topic_counts):
205
  """Format topics for display."""
206
  formatted_topics = []
@@ -252,41 +253,40 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
252
  topic_model_params["nr_topics"] = "auto"
253
 
254
  topic_model = BERTopic(
255
- embedding_model=bert_model,
256
  **topic_model_params
257
  )
258
 
259
- vectorizer = CountVectorizer(stop_words=list(ARABIC_STOP_WORDS),
260
- min_df=1,
261
- max_df=1.0)
 
 
262
  topic_model.vectorizer_model = vectorizer
263
 
264
- # Create a placeholder for the progress bar
265
  progress_placeholder = st.empty()
266
  progress_bar = progress_placeholder.progress(0)
267
-
268
- # Create status message placeholder
269
  status_message = st.empty()
270
 
271
  for country, group in df.groupby('country'):
272
- # Clear memory at the start of each country's processing
273
  gc.collect()
274
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
275
 
276
  status_message.text(f"Processing poems for {country}...")
277
  texts = [clean_arabic_text(poem) for poem in group['poem'].dropna()]
278
  all_emotions = []
 
279
 
280
- # Use cached embeddings with progress tracking
281
- embeddings = []
282
  total_texts = len(texts)
283
  for i, text in enumerate(texts):
284
  try:
285
  embedding = cache_embeddings(text, bert_tokenizer, bert_model)
286
  if embedding is not None and not np.isnan(embedding).any():
287
- embeddings.append(embedding)
 
 
 
288
 
289
- # Update progress more frequently
290
  if i % max(1, total_texts // 100) == 0:
291
  progress = (i + 1) / total_texts * 0.4
292
  progress_bar.progress(progress)
@@ -296,7 +296,7 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
296
  st.warning(f"Error processing poem {i+1} in {country}: {str(e)}")
297
  continue
298
 
299
- # Process emotions with caching and progress tracking
300
  for i, text in enumerate(texts):
301
  try:
302
  emotion = cache_emotion_classification(text, emotion_classifier)
@@ -316,30 +316,32 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
316
  st.warning(f"Not enough documents for {country} to generate meaningful topics (minimum {min_topic_size} required)")
317
  continue
318
 
319
- topics, probs = topic_model.fit_transform(texts, embeddings)
320
-
321
- topic_counts = Counter(topics)
322
-
323
- top_topics = format_topics(topic_model, topic_counts.most_common(top_n))
324
- top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
325
-
326
- summaries.append({
327
- 'country': country,
328
- 'total_poems': len(texts),
329
- 'top_topics': top_topics,
330
- 'top_emotions': top_emotions
331
- })
332
- progress_bar.progress(1.0, text="Processing complete!")
333
-
 
 
 
 
 
334
  except Exception as e:
335
  st.warning(f"Could not generate topics for {country}: {str(e)}")
336
  continue
337
 
338
- # Clear progress for next country
339
  progress_placeholder.empty()
340
  status_message.empty()
341
-
342
- # Create new progress bar for next country
343
  progress_placeholder = st.empty()
344
  progress_bar = progress_placeholder.progress(0)
345
  status_message = st.empty()
 
195
  continue
196
 
197
  if chunk_embeddings:
198
+ # Convert to numpy array and ensure 2D shape
199
+ chunk_embeddings = np.array(chunk_embeddings)
200
+ if len(chunk_embeddings.shape) == 1:
201
+ chunk_embeddings = chunk_embeddings.reshape(1, -1)
202
+ return chunk_embeddings
203
+ return np.zeros((1, model.config.hidden_size))
204
+
205
  def format_topics(topic_model, topic_counts):
206
  """Format topics for display."""
207
  formatted_topics = []
 
253
  topic_model_params["nr_topics"] = "auto"
254
 
255
  topic_model = BERTopic(
256
+ embedding_model=None, # Set to None since we're providing embeddings
257
  **topic_model_params
258
  )
259
 
260
+ vectorizer = CountVectorizer(
261
+ stop_words=list(ARABIC_STOP_WORDS),
262
+ min_df=1,
263
+ max_df=1.0
264
+ )
265
  topic_model.vectorizer_model = vectorizer
266
 
 
267
  progress_placeholder = st.empty()
268
  progress_bar = progress_placeholder.progress(0)
 
 
269
  status_message = st.empty()
270
 
271
  for country, group in df.groupby('country'):
 
272
  gc.collect()
273
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
274
 
275
  status_message.text(f"Processing poems for {country}...")
276
  texts = [clean_arabic_text(poem) for poem in group['poem'].dropna()]
277
  all_emotions = []
278
+ embeddings_list = []
279
 
 
 
280
  total_texts = len(texts)
281
  for i, text in enumerate(texts):
282
  try:
283
  embedding = cache_embeddings(text, bert_tokenizer, bert_model)
284
  if embedding is not None and not np.isnan(embedding).any():
285
+ # Ensure embedding is 2D
286
+ if len(embedding.shape) == 1:
287
+ embedding = embedding.reshape(1, -1)
288
+ embeddings_list.append(embedding)
289
 
 
290
  if i % max(1, total_texts // 100) == 0:
291
  progress = (i + 1) / total_texts * 0.4
292
  progress_bar.progress(progress)
 
296
  st.warning(f"Error processing poem {i+1} in {country}: {str(e)}")
297
  continue
298
 
299
+ # Process emotions
300
  for i, text in enumerate(texts):
301
  try:
302
  emotion = cache_emotion_classification(text, emotion_classifier)
 
316
  st.warning(f"Not enough documents for {country} to generate meaningful topics (minimum {min_topic_size} required)")
317
  continue
318
 
319
+ if embeddings_list:
320
+ # Stack all embeddings into a single 2D array
321
+ embeddings = np.vstack(embeddings_list)
322
+
323
+ topics, probs = topic_model.fit_transform(texts, embeddings)
324
+ topic_counts = Counter(topics)
325
+
326
+ top_topics = format_topics(topic_model, topic_counts.most_common(top_n))
327
+ top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
328
+
329
+ summaries.append({
330
+ 'country': country,
331
+ 'total_poems': len(texts),
332
+ 'top_topics': top_topics,
333
+ 'top_emotions': top_emotions
334
+ })
335
+ progress_bar.progress(1.0, text="Processing complete!")
336
+ else:
337
+ st.warning(f"No valid embeddings generated for {country}")
338
+
339
  except Exception as e:
340
  st.warning(f"Could not generate topics for {country}: {str(e)}")
341
  continue
342
 
 
343
  progress_placeholder.empty()
344
  status_message.empty()
 
 
345
  progress_placeholder = st.empty()
346
  progress_bar = progress_placeholder.progress(0)
347
  status_message = st.empty()