Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline | |
from bertopic import BERTopic | |
import torch | |
import numpy as np | |
from collections import Counter | |
import os | |
# Configure page: browser-tab title and icon, and use the full page width.
# page_icon must be a real Unicode emoji (the previous value was a
# mis-decoded byte sequence, not an emoji).
st.set_page_config(
    page_title="Arabic Poem Analysis",
    page_icon="📚",
    layout="wide",
)
@st.cache_resource
def load_models():
    """Load the embedding model and the sentiment classifier once per process.

    ``st.cache_resource`` keeps the returned objects across Streamlit reruns,
    so widget interactions do not reload the model weights.

    Returns:
        tuple: ``(bert_tokenizer, bert_model, emotion_classifier)``. The first
        two are the AraBERTv2 tokenizer and encoder used for embeddings. The
        last is a ``text-classification`` pipeline over the CAMeL-Lab
        sentiment model.
    """
    bert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
    bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")

    # The sentiment model was trained with its own vocabulary. Feeding it ids
    # produced by the AraBERT tokenizer would look up the wrong (or
    # out-of-range) embeddings, so pair it with its matching tokenizer.
    emotion_model_name = "CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment"
    emotion_tokenizer = AutoTokenizer.from_pretrained(emotion_model_name)
    emotion_model = AutoModelForSequenceClassification.from_pretrained(emotion_model_name)
    emotion_classifier = pipeline(
        "text-classification",
        model=emotion_model,
        tokenizer=emotion_tokenizer,
    )
    return bert_tokenizer, bert_model, emotion_classifier
# Load the models at startup. If anything fails, show the error and halt
# the script so nothing below runs without models.
try:
    bert_tokenizer, bert_model, emotion_classifier = load_models()
    st.success("Models loaded successfully!")
except Exception as load_error:
    st.error(f"Error loading models: {load_error!s}")
    st.stop()
# Map raw classifier labels to display names; unknown labels pass through
# unchanged in format_emotions.
# The CAMeL-Lab sentiment model emits named lowercase labels ("positive",
# "negative", "neutral"). The generic LABEL_<i> keys are kept for models
# that ship without an id2label mapping.
# NOTE(review): the LABEL_<i> ordering below is unverified; confirm it
# against the model config before relying on it.
EMOTION_LABELS = {
    'LABEL_0': 'Negative',
    'LABEL_1': 'Positive',
    'LABEL_2': 'Neutral',
    'negative': 'Negative',
    'positive': 'Positive',
    'neutral': 'Neutral',
}
def chunk_long_text(text, tokenizer, max_length=512):
    """Split ``text`` into pieces that each fit in ``max_length`` tokens.

    Two positions per chunk are reserved for the CLS and SEP tokens, so each
    chunk holds at most ``max_length - 2`` content tokens.

    Args:
        text: The text to split.
        tokenizer: A Hugging Face-style tokenizer exposing ``encode``,
            ``decode``, ``cls_token_id`` and ``sep_token_id``.
        max_length: Maximum tokens per chunk, including the two special
            tokens. Must be greater than 2.

    Returns:
        tuple[list[list[int]], list[str]]: the token-id chunks wrapped in
        CLS/SEP, and the decoded text of each chunk without special tokens.
        Both lists are empty when ``text`` encodes to no tokens.

    Raises:
        ValueError: if ``max_length`` leaves no room for content tokens.
    """
    window = max_length - 2
    if window < 1:
        # range() raises a cryptic error on a zero step and silently yields
        # nothing on a negative step, which would drop the whole text.
        raise ValueError(f"max_length must be greater than 2, got {max_length}")

    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    text_chunks = []
    for start in range(0, len(tokens), window):
        piece = tokens[start:start + window]
        chunks.append([tokenizer.cls_token_id, *piece, tokenizer.sep_token_id])
        text_chunks.append(tokenizer.decode(piece))
    return chunks, text_chunks
def get_embedding_for_text(text, tokenizer, model):
    """Return one vector for ``text`` by averaging per-chunk CLS embeddings.

    The text is split with ``chunk_long_text``. Each piece is encoded
    separately, and the last layer's hidden state at position 0 is taken as
    that piece's embedding. The result is the mean over pieces. If the text
    yields no pieces, a zero vector of length ``model.config.hidden_size`` is
    returned.
    """
    _, pieces = chunk_long_text(text, tokenizer)
    if not pieces:
        return np.zeros(model.config.hidden_size)

    vectors = []
    for piece in pieces:
        encoded = tokenizer(
            piece,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Move every input tensor onto the same device as the model.
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
        vectors.append(hidden[:, 0, :].cpu().numpy()[0])
    return np.mean(vectors, axis=0)
def generate_embeddings(texts, tokenizer, model):
    """Embed every text in ``texts`` and stack the results into one array.

    A text that fails to embed is reported with ``st.warning`` and replaced
    by a zero vector, so the output always has one row per input text.
    """
    def _embed_or_zeros(text):
        try:
            return get_embedding_for_text(text, tokenizer, model)
        except Exception as err:
            st.warning(f"Error processing text: {err!s}")
            return np.zeros(model.config.hidden_size)

    return np.array([_embed_or_zeros(text) for text in texts])
def classify_emotion(text, tokenizer, classifier):
    """Return the majority sentiment label for ``text``.

    The text is split into chunks of at most 512 tokens and each chunk is
    classified. The most frequent label wins; ties go to the label seen
    first.

    Args:
        text: The text to classify.
        tokenizer: Fallback tokenizer for measuring chunk length, used only
            when ``classifier`` has no ``tokenizer`` attribute.
        classifier: A ``text-classification`` pipeline.

    Returns:
        str: the winning raw label. Returns ``"unknown"`` when the text
        produced no chunks, or when classification raised (a warning is
        shown in that case).
    """
    # Measure chunks with the tokenizer the classifier actually uses. A chunk
    # sized with a different vocabulary can exceed the model's limit and be
    # silently truncated.
    chunk_tokenizer = getattr(classifier, "tokenizer", None)
    if chunk_tokenizer is None:
        chunk_tokenizer = tokenizer
    try:
        _, text_chunks = chunk_long_text(text, chunk_tokenizer)
        if not text_chunks:
            return "unknown"
        # One pipeline call over all chunks; each result is a dict with a
        # 'label' key.
        results = classifier(text_chunks, max_length=512, truncation=True)
        labels = [result['label'] for result in results]
        return Counter(labels).most_common(1)[0][0]
    except Exception as e:
        st.warning(f"Error in emotion classification: {str(e)}")
        return "unknown"
def format_topics(topic_model, topic_counts):
    """Turn ``(topic_id, count)`` pairs into display dicts.

    Topic ``-1`` (BERTopic's outlier topic) is labelled "Miscellaneous". Any
    other topic is labelled with its first three words from
    ``topic_model.get_topic``, joined by " | ".

    Returns:
        list[dict]: one ``{'topic': label, 'count': count}`` per input pair,
        in input order.
    """
    def _label(topic_id):
        if topic_id == -1:
            return "Miscellaneous"
        top_words = topic_model.get_topic(topic_id)[:3]
        return " | ".join(word for word, _score in top_words)

    return [
        {'topic': _label(topic_id), 'count': count}
        for topic_id, count in topic_counts
    ]
def format_emotions(emotion_counts):
    """Turn ``(label, count)`` pairs into display dicts.

    Labels present in ``EMOTION_LABELS`` are replaced by their display name;
    any other label is kept as-is.

    Returns:
        list[dict]: one ``{'emotion': name, 'count': count}`` per input pair,
        in input order.
    """
    return [
        {'emotion': EMOTION_LABELS.get(raw_label, raw_label), 'count': count}
        for raw_label, count in emotion_counts
    ]
def _new_topic_model():
    """Return a fresh, unfitted BERTopic model with this app's settings."""
    return BERTopic(
        language="arabic",
        calculate_probabilities=True,
        min_topic_size=5,
        verbose=True,
    )


def process_and_summarize(df, top_n=50):
    """Summarise topics and emotions per country and fit a global topic model.

    For each ``country`` group in ``df``, this:

    - embeds the group's ``poem`` texts with the AraBERT model;
    - classifies each poem's emotion;
    - fits a BERTopic model on that country's poems alone.

    After the loop, a separate BERTopic model is fitted on the poems of all
    countries together, reusing the embeddings already computed.

    Args:
        df: DataFrame with ``country`` and ``poem`` columns.
        top_n: How many of the most frequent topics and emotions to keep per
            country.

    Returns:
        tuple: ``(summaries, topic_model)``.

        ``summaries`` is a list of dicts with keys ``country``,
        ``total_poems``, ``top_topics`` and ``top_emotions``. Countries with
        no poems, or whose topic fit fails, are skipped with a warning.

        ``topic_model`` is fitted on all poems. If that fit fails, the last
        successfully fitted per-country model is returned instead, or an
        unfitted model if none succeeded.
    """
    summaries = []
    all_texts = []
    all_embeddings = []
    last_fitted_model = None

    for country, group in df.groupby('country'):
        # Coerce to str so numeric cells (e.g. from Excel) don't break the
        # tokenizers.
        texts = group['poem'].dropna().astype(str).tolist()
        if not texts:
            st.warning(f"No poems found for {country}; skipping.")
            continue

        # Each status message is shown before the step it describes.
        progress_bar = st.progress(0, text=f"{country}: generating embeddings...")
        embeddings = generate_embeddings(texts, bert_tokenizer, bert_model)
        # Keep every country's data for the global model, even if this
        # country's own topic fit fails below.
        all_texts.extend(texts)
        all_embeddings.append(embeddings)

        progress_bar.progress(0.33, text=f"{country}: classifying emotions...")
        all_emotions = [
            classify_emotion(text, bert_tokenizer, emotion_classifier)
            for text in texts
        ]

        progress_bar.progress(0.66, text=f"{country}: fitting topic model...")
        try:
            # fit_transform refits in place, so use a fresh model per country.
            # Otherwise one shared instance ends up holding only the last
            # country's topics.
            country_model = _new_topic_model()
            topics, _ = country_model.fit_transform(texts, embeddings)
            top_topics = format_topics(country_model, Counter(topics).most_common(top_n))
            top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
        except Exception as e:
            st.warning(f"Could not generate topics for {country}: {str(e)}")
            continue

        last_fitted_model = country_model
        summaries.append({
            'country': country,
            'total_poems': len(texts),
            'top_topics': top_topics,
            'top_emotions': top_emotions,
        })
        progress_bar.progress(1.0, text=f"{country}: processing complete!")

    topic_model = _new_topic_model()
    if all_texts:
        try:
            topic_model.fit_transform(all_texts, np.vstack(all_embeddings))
        except Exception as e:
            st.warning(f"Could not generate global topics: {str(e)}")
            if last_fitted_model is not None:
                topic_model = last_fitted_model
    return summaries, topic_model
# Main app interface
st.title("📚 Arabic Poem Analysis")
st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")

# File upload
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])

if uploaded_file is not None:
    try:
        # Read the file. Compare the extension case-insensitively so that
        # "POEMS.CSV" is not handed to the Excel reader.
        if uploaded_file.name.lower().endswith('.csv'):
            df = pd.read_csv(uploaded_file)
        else:
            df = pd.read_excel(uploaded_file)

        # Validate columns
        required_columns = ['country', 'poem']
        if not all(col in df.columns for col in required_columns):
            st.error("File must contain 'country' and 'poem' columns.")
            st.stop()

        # Clean data: drop incomplete rows first, then normalise country
        # names. astype(str) keeps .str usable when the column was parsed as
        # numbers.
        df = df.dropna(subset=['country', 'poem']).copy()
        df['country'] = df['country'].astype(str).str.strip()
        df = df[df['country'] != '']
        if df.empty:
            st.warning("No rows with both a country and a poem were found.")
            st.stop()

        # Process data
        top_n = st.number_input("Number of top topics/emotions to display:",
                                min_value=1, max_value=100, value=10)

        if st.button("Process Data"):
            with st.spinner("Processing your data..."):
                summaries, topic_model = process_and_summarize(df, top_n=top_n)

            if summaries:
                st.success("Analysis complete!")

                # Display results in tabs
                tab1, tab2 = st.tabs(["Country Summaries", "Global Topics"])

                with tab1:
                    for summary in summaries:
                        with st.expander(f"🌍 {summary['country']} ({summary['total_poems']} poems)"):
                            col1, col2 = st.columns(2)
                            with col1:
                                st.subheader("Top Topics")
                                for topic in summary['top_topics']:
                                    st.write(f"• {topic['topic']}: {topic['count']} poems")
                            with col2:
                                st.subheader("Emotions")
                                for emotion in summary['top_emotions']:
                                    st.write(f"• {emotion['emotion']}: {emotion['count']} poems")

                with tab2:
                    st.subheader("Global Topic Distribution")
                    topic_info = topic_model.get_topic_info()
                    for _, row in topic_info.iterrows():
                        if row['Topic'] == -1:
                            topic_name = "Miscellaneous"
                        else:
                            words = topic_model.get_topic(row['Topic'])
                            topic_name = " | ".join([word for word, _ in words[:3]])
                        st.write(f"• Topic {row['Topic']}: {topic_name} ({row['Count']} poems)")
            else:
                st.warning("No summaries could be generated from this file.")

    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
else:
    st.info("👆 Upload a file to get started!")

    # Example format
    st.write("### Expected File Format:")
    example_df = pd.DataFrame({
        'country': ['Egypt', 'Saudi Arabia'],
        'poem': ['قصيدة مصرية', 'قصيدة سعودية'],
    })
    st.dataframe(example_df)