import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from bertopic import BERTopic
import torch
import numpy as np
from collections import Counter
# Configure page
st.set_page_config(
    page_title="Arabic Poem Analysis",
    page_icon="📖",
    layout="wide"
)
@st.cache_resource
def load_models():
    """Load and cache the models so they are not reloaded on every Streamlit rerun."""
    # The embedding tokenizer must match the embedding model (AraBERT);
    # the emotion pipeline pairs the CAMeL-Lab model with its own tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
    bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
    emotion_model = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
    emotion_tokenizer = AutoTokenizer.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
    emotion_classifier = pipeline(
        "sentiment-analysis",
        model=emotion_model,
        tokenizer=emotion_tokenizer,
        return_all_scores=True
    )
    return tokenizer, bert_model, emotion_classifier
def split_text(text, max_length=512):
    """Split text into chunks of at most max_length words, preserving word boundaries.

    Word count is a rough proxy for token count; the tokenizer still truncates
    any chunk that exceeds the model's true token limit downstream.
    """
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        if current_length + 1 > max_length:
            if current_chunk:  # Only append if there are words in the current chunk
                chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_length = 1
        else:
            current_chunk.append(word)
            current_length += 1
    if current_chunk:  # Append the last chunk if it exists
        chunks.append(' '.join(current_chunk))
    return chunks
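# A quick illustration of the chunking behavior (hypothetical call, not part
# of the app): with max_length=3, a seven-word input splits into word-bounded
# chunks of at most three words each.
#   split_text("one two three four five six seven", max_length=3)
#   # -> ['one two three', 'four five six', 'seven']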
def classify_emotion(text, classifier):
    """Classify emotion for a complete text, chunking to respect the 512-token limit."""
    try:
        # Split text into chunks whose token counts stay under the model limit
        words = text.split()
        chunks = []
        current_chunk = []
        current_length = 0
        # A budget of 510 leaves room for the [CLS] and [SEP] special tokens
        for word in words:
            word_tokens = len(classifier.tokenizer.encode(word, add_special_tokens=False))
            if current_length + word_tokens > 510:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                current_chunk = [word]
                current_length = word_tokens
            else:
                current_chunk.append(word)
                current_length += word_tokens
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        # If no chunks were created, fall back to the original text with truncation
        if not chunks:
            chunks = [text]
        all_scores = []
        for chunk in chunks:
            try:
                # truncation/max_length guard against any chunk that still overshoots
                result = classifier(chunk, truncation=True, max_length=512)
                scores = result[0]
                all_scores.append(scores)
            except Exception as chunk_error:
                st.warning(f"Skipping chunk due to error: {str(chunk_error)}")
                continue
        # Average scores across all chunks
        if all_scores:
            label_scores = {}
            count = len(all_scores)
            # Sum up scores for each label
            for scores in all_scores:
                for score in scores:
                    label = score['label']
                    label_scores[label] = label_scores.get(label, 0) + score['score']
            # Average, then pick the label with the highest mean score
            avg_scores = {label: total / count for label, total in label_scores.items()}
            return max(avg_scores.items(), key=lambda x: x[1])[0]
        return "LABEL_2"  # Default to neutral if no valid results
    except Exception as e:
        st.warning(f"Error in emotion classification: {str(e)}")
        return "LABEL_2"  # Default to neutral
def get_embedding_for_text(text, tokenizer, model):
    """Get a single embedding for a complete text by averaging chunk embeddings."""
    chunks = split_text(text)
    chunk_embeddings = []
    chunk_lengths = []  # track word counts only for chunks that embed successfully
    for chunk in chunks:
        try:
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model(**inputs)
            # Use the [CLS] token representation as the chunk embedding
            embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            chunk_embeddings.append(embedding[0])
            chunk_lengths.append(len(chunk.split()))
        except Exception as e:
            st.warning(f"Error processing chunk: {str(e)}")
            continue
    if chunk_embeddings:
        # Weighted average by chunk length; weights must align with the
        # successfully embedded chunks, not the original chunk list
        weights = np.array(chunk_lengths, dtype=float)
        weights = weights / weights.sum()
        return np.average(chunk_embeddings, axis=0, weights=weights)
    return np.zeros(model.config.hidden_size)
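# Minimal sketch of the weighted average on toy data (NumPy only; shapes and
# values are illustrative, not model output):
#   embs = [np.ones(4), np.zeros(4)]          # two chunk embeddings
#   w = np.array([3.0, 1.0]); w /= w.sum()    # 3-word and 1-word chunks
#   np.average(embs, axis=0, weights=w)       # -> array([0.75, 0.75, 0.75, 0.75])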
def format_topics(topic_model, topic_counts):
    """Format topics for display."""
    formatted_topics = []
    for topic_num, count in topic_counts:
        if topic_num == -1:
            topic_label = "Miscellaneous"
        else:
            words = topic_model.get_topic(topic_num)
            topic_label = " | ".join([word for word, _ in words[:5]])  # Top 5 words per topic
        formatted_topics.append({
            'topic': topic_label,
            'count': count
        })
    return formatted_topics
def format_emotions(emotion_counts):
    """Format emotions for display."""
    # Map the raw model labels to readable names
    EMOTION_LABELS = {
        'LABEL_0': 'Negative',
        'LABEL_1': 'Positive',
        'LABEL_2': 'Neutral'
    }
    formatted_emotions = []
    for label, count in emotion_counts:
        formatted_emotions.append({
            'emotion': EMOTION_LABELS.get(label, label),
            'count': count
        })
    return formatted_emotions
def process_and_summarize(df, top_n=50):
    """Process the data and generate per-country summaries.

    Relies on the module-level bert_tokenizer, bert_model and
    emotion_classifier loaded below.
    """
    summaries = []
    # Initialize BERTopic with Arabic-friendly settings
    topic_model = BERTopic(
        language="multilingual",
        calculate_probabilities=True,
        min_topic_size=2,     # Allow smaller topic groups
        n_gram_range=(1, 3),  # Include up to trigrams
        top_n_words=15,       # Keep more words per topic
        verbose=True
    )
    # Group by country
    for country, group in df.groupby('country'):
        progress_text = f"Processing poems for {country}..."
        progress_bar = st.progress(0, text=progress_text)
        texts = group['poem'].dropna().tolist()
        all_emotions = []
        # Generate embeddings with progress tracking (first 40% of the bar)
        embeddings = []
        for i, text in enumerate(texts):
            embedding = get_embedding_for_text(text, bert_tokenizer, bert_model)
            embeddings.append(embedding)
            progress = (i + 1) / len(texts) * 0.4
            progress_bar.progress(progress, text=f"Generated embeddings for {i+1}/{len(texts)} poems...")
        embeddings = np.array(embeddings)
        # Classify emotions with progress tracking (next 30% of the bar)
        for i, text in enumerate(texts):
            emotion = classify_emotion(text, emotion_classifier)
            all_emotions.append(emotion)
            progress = 0.4 + ((i + 1) / len(texts) * 0.3)
            progress_bar.progress(progress, text=f"Classified emotions for {i+1}/{len(texts)} poems...")
        try:
            # Fit the topic model on the texts using the precomputed embeddings
            topics, _ = topic_model.fit_transform(texts, embeddings)
            # Format results
            top_topics = format_topics(topic_model, Counter(topics).most_common(top_n))
            top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
            summaries.append({
                'country': country,
                'total_poems': len(texts),
                'top_topics': top_topics,
                'top_emotions': top_emotions
            })
            progress_bar.progress(1.0, text="Processing complete!")
        except Exception as e:
            st.warning(f"Could not generate topics for {country}: {str(e)}")
            continue
    return summaries, topic_model
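# Each element of `summaries` has this shape (values illustrative):
#   {'country': 'Egypt',
#    'total_poems': 120,
#    'top_topics': [{'topic': 'word1 | word2 | word3', 'count': 15}, ...],
#    'top_emotions': [{'emotion': 'Positive', 'count': 60}, ...]}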
# Load models
try:
    bert_tokenizer, bert_model, emotion_classifier = load_models()
    st.success("Models loaded successfully!")
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()
# Main app interface
st.title("📖 Arabic Poem Analysis")
st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")

# File upload
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
    try:
        # Read the file
        if uploaded_file.name.endswith('.csv'):
            df = pd.read_csv(uploaded_file)
        else:
            df = pd.read_excel(uploaded_file)
        # Validate columns
        required_columns = ['country', 'poem']
        if not all(col in df.columns for col in required_columns):
            st.error("File must contain 'country' and 'poem' columns.")
            st.stop()
        # Clean data
        df['country'] = df['country'].str.strip()
        df = df.dropna(subset=['country', 'poem'])
        # Process data
        top_n = st.number_input("Number of top topics/emotions to display:",
                                min_value=1, max_value=100, value=10)
        if st.button("Process Data"):
            with st.spinner("Processing your data..."):
                summaries, topic_model = process_and_summarize(df, top_n=top_n)
                if summaries:
                    st.success("Analysis complete!")
                    # Display results in tabs
                    tab1, tab2 = st.tabs(["Country Summaries", "Global Topics"])
                    with tab1:
                        for summary in summaries:
                            with st.expander(f"📍 {summary['country']} ({summary['total_poems']} poems)"):
                                col1, col2 = st.columns(2)
                                with col1:
                                    st.subheader("Top Topics")
                                    for topic in summary['top_topics']:
                                        st.write(f"• {topic['topic']}: {topic['count']} poems")
                                with col2:
                                    st.subheader("Emotions")
                                    for emotion in summary['top_emotions']:
                                        st.write(f"• {emotion['emotion']}: {emotion['count']} poems")
                    with tab2:
                        st.subheader("Global Topic Distribution")
                        topic_info = topic_model.get_topic_info()
                        for _, row in topic_info.iterrows():
                            if row['Topic'] == -1:
                                topic_name = "Miscellaneous"
                            else:
                                words = topic_model.get_topic(row['Topic'])
                                topic_name = " | ".join([word for word, _ in words[:5]])
                            st.write(f"• Topic {row['Topic']}: {topic_name} ({row['Count']} poems)")
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
else:
    st.info("👆 Upload a file to get started!")
    # Example format
    st.write("### Expected File Format:")
    example_df = pd.DataFrame({
        'country': ['Egypt', 'Palestine'],
        'poem': ['قصيدة مصرية', 'قصيدة فلسطينية']
    })
    st.dataframe(example_df)