Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -195,12 +195,13 @@ def get_embedding_for_text(text, tokenizer, model):
|
|
195 |
continue
|
196 |
|
197 |
if chunk_embeddings:
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
204 |
def format_topics(topic_model, topic_counts):
|
205 |
"""Format topics for display."""
|
206 |
formatted_topics = []
|
@@ -252,41 +253,40 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
|
|
252 |
topic_model_params["nr_topics"] = "auto"
|
253 |
|
254 |
topic_model = BERTopic(
|
255 |
-
embedding_model=
|
256 |
**topic_model_params
|
257 |
)
|
258 |
|
259 |
-
vectorizer = CountVectorizer(
|
260 |
-
|
261 |
-
|
|
|
|
|
262 |
topic_model.vectorizer_model = vectorizer
|
263 |
|
264 |
-
# Create a placeholder for the progress bar
|
265 |
progress_placeholder = st.empty()
|
266 |
progress_bar = progress_placeholder.progress(0)
|
267 |
-
|
268 |
-
# Create status message placeholder
|
269 |
status_message = st.empty()
|
270 |
|
271 |
for country, group in df.groupby('country'):
|
272 |
-
# Clear memory at the start of each country's processing
|
273 |
gc.collect()
|
274 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
275 |
|
276 |
status_message.text(f"Processing poems for {country}...")
|
277 |
texts = [clean_arabic_text(poem) for poem in group['poem'].dropna()]
|
278 |
all_emotions = []
|
|
|
279 |
|
280 |
-
# Use cached embeddings with progress tracking
|
281 |
-
embeddings = []
|
282 |
total_texts = len(texts)
|
283 |
for i, text in enumerate(texts):
|
284 |
try:
|
285 |
embedding = cache_embeddings(text, bert_tokenizer, bert_model)
|
286 |
if embedding is not None and not np.isnan(embedding).any():
|
287 |
-
|
|
|
|
|
|
|
288 |
|
289 |
-
# Update progress more frequently
|
290 |
if i % max(1, total_texts // 100) == 0:
|
291 |
progress = (i + 1) / total_texts * 0.4
|
292 |
progress_bar.progress(progress)
|
@@ -296,7 +296,7 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
|
|
296 |
st.warning(f"Error processing poem {i+1} in {country}: {str(e)}")
|
297 |
continue
|
298 |
|
299 |
-
# Process emotions
|
300 |
for i, text in enumerate(texts):
|
301 |
try:
|
302 |
emotion = cache_emotion_classification(text, emotion_classifier)
|
@@ -316,30 +316,32 @@ def process_and_summarize(df, bert_tokenizer, bert_model, emotion_classifier, to
|
|
316 |
st.warning(f"Not enough documents for {country} to generate meaningful topics (minimum {min_topic_size} required)")
|
317 |
continue
|
318 |
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
334 |
except Exception as e:
|
335 |
st.warning(f"Could not generate topics for {country}: {str(e)}")
|
336 |
continue
|
337 |
|
338 |
-
# Clear progress for next country
|
339 |
progress_placeholder.empty()
|
340 |
status_message.empty()
|
341 |
-
|
342 |
-
# Create new progress bar for next country
|
343 |
progress_placeholder = st.empty()
|
344 |
progress_bar = progress_placeholder.progress(0)
|
345 |
status_message = st.empty()
|
|
|
195 |
continue
|
196 |
|
197 |
if chunk_embeddings:
|
198 |
+
# Convert to numpy array and ensure 2D shape
|
199 |
+
chunk_embeddings = np.array(chunk_embeddings)
|
200 |
+
if len(chunk_embeddings.shape) == 1:
|
201 |
+
chunk_embeddings = chunk_embeddings.reshape(1, -1)
|
202 |
+
return chunk_embeddings
|
203 |
+
return np.zeros((1, model.config.hidden_size))
|
204 |
+
|
205 |
def format_topics(topic_model, topic_counts):
|
206 |
"""Format topics for display."""
|
207 |
formatted_topics = []
|
|
|
253 |
topic_model_params["nr_topics"] = "auto"
|
254 |
|
255 |
topic_model = BERTopic(
|
256 |
+
embedding_model=None, # Set to None since we're providing embeddings
|
257 |
**topic_model_params
|
258 |
)
|
259 |
|
260 |
+
vectorizer = CountVectorizer(
|
261 |
+
stop_words=list(ARABIC_STOP_WORDS),
|
262 |
+
min_df=1,
|
263 |
+
max_df=1.0
|
264 |
+
)
|
265 |
topic_model.vectorizer_model = vectorizer
|
266 |
|
|
|
267 |
progress_placeholder = st.empty()
|
268 |
progress_bar = progress_placeholder.progress(0)
|
|
|
|
|
269 |
status_message = st.empty()
|
270 |
|
271 |
for country, group in df.groupby('country'):
|
|
|
272 |
gc.collect()
|
273 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
274 |
|
275 |
status_message.text(f"Processing poems for {country}...")
|
276 |
texts = [clean_arabic_text(poem) for poem in group['poem'].dropna()]
|
277 |
all_emotions = []
|
278 |
+
embeddings_list = []
|
279 |
|
|
|
|
|
280 |
total_texts = len(texts)
|
281 |
for i, text in enumerate(texts):
|
282 |
try:
|
283 |
embedding = cache_embeddings(text, bert_tokenizer, bert_model)
|
284 |
if embedding is not None and not np.isnan(embedding).any():
|
285 |
+
# Ensure embedding is 2D
|
286 |
+
if len(embedding.shape) == 1:
|
287 |
+
embedding = embedding.reshape(1, -1)
|
288 |
+
embeddings_list.append(embedding)
|
289 |
|
|
|
290 |
if i % max(1, total_texts // 100) == 0:
|
291 |
progress = (i + 1) / total_texts * 0.4
|
292 |
progress_bar.progress(progress)
|
|
|
296 |
st.warning(f"Error processing poem {i+1} in {country}: {str(e)}")
|
297 |
continue
|
298 |
|
299 |
+
# Process emotions
|
300 |
for i, text in enumerate(texts):
|
301 |
try:
|
302 |
emotion = cache_emotion_classification(text, emotion_classifier)
|
|
|
316 |
st.warning(f"Not enough documents for {country} to generate meaningful topics (minimum {min_topic_size} required)")
|
317 |
continue
|
318 |
|
319 |
+
if embeddings_list:
|
320 |
+
# Stack all embeddings into a single 2D array
|
321 |
+
embeddings = np.vstack(embeddings_list)
|
322 |
+
|
323 |
+
topics, probs = topic_model.fit_transform(texts, embeddings)
|
324 |
+
topic_counts = Counter(topics)
|
325 |
+
|
326 |
+
top_topics = format_topics(topic_model, topic_counts.most_common(top_n))
|
327 |
+
top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
|
328 |
+
|
329 |
+
summaries.append({
|
330 |
+
'country': country,
|
331 |
+
'total_poems': len(texts),
|
332 |
+
'top_topics': top_topics,
|
333 |
+
'top_emotions': top_emotions
|
334 |
+
})
|
335 |
+
progress_bar.progress(1.0, text="Processing complete!")
|
336 |
+
else:
|
337 |
+
st.warning(f"No valid embeddings generated for {country}")
|
338 |
+
|
339 |
except Exception as e:
|
340 |
st.warning(f"Could not generate topics for {country}: {str(e)}")
|
341 |
continue
|
342 |
|
|
|
343 |
progress_placeholder.empty()
|
344 |
status_message.empty()
|
|
|
|
|
345 |
progress_placeholder = st.empty()
|
346 |
progress_bar = progress_placeholder.progress(0)
|
347 |
status_message = st.empty()
|