# Spaces status shown on the page this file was captured from: Runtime error
from collections import Counter

import numpy as np
import pandas as pd
import streamlit as st
import torch
from bertopic import BERTopic
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
# Checkpoint shared by the tokenizer, the embedding encoder and the classifier.
_ARABERT_CHECKPOINT = "aubmindlab/bert-base-arabertv2"


@st.cache_resource
def _load_models():
    """Load the AraBERT tokenizer, encoder and emotion classifier once.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects as a resource avoids re-loading the three
    models on each rerun.

    Returns:
        Tuple ``(tokenizer, encoder, classifier_model, classifier_pipeline)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(_ARABERT_CHECKPOINT)
    # Encoder used for sentence embeddings.
    encoder = AutoModel.from_pretrained(_ARABERT_CHECKPOINT)
    # NOTE(review): the classifier is loaded from the same base checkpoint as
    # the encoder. Unless that checkpoint ships a fine-tuned classification
    # head, the predicted labels are not real emotion labels -- confirm, or
    # point this at an emotion fine-tuned checkpoint.
    classifier_model = AutoModelForSequenceClassification.from_pretrained(_ARABERT_CHECKPOINT)
    classifier = pipeline("text-classification", model=classifier_model, tokenizer=tokenizer)
    return tokenizer, encoder, classifier_model, classifier


# Module-level names kept unchanged for the rest of the file.
bert_tokenizer, bert_model, emotion_model, emotion_classifier = _load_models()
# Function to generate embeddings using AraBERT
def generate_embeddings(texts, batch_size=32):
    """Return one mean-pooled AraBERT embedding per input text.

    Args:
        texts: A string or a sequence of strings to embed. Inputs longer
            than 512 tokens are truncated by the tokenizer.
        batch_size: Number of texts sent through the model per forward pass,
            so memory use does not grow with the total number of texts.

    Returns:
        A 2-D NumPy array with one row per text. For empty input the array
        has zero rows and ``bert_model.config.hidden_size`` columns.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        # The tokenizer cannot handle an empty batch; return an empty matrix.
        return torch.empty((0, bert_model.config.hidden_size)).numpy()

    pooled_batches = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = bert_model(**inputs)
        hidden = outputs.last_hidden_state
        # Weight by the attention mask so padding positions do not contribute
        # to the mean; otherwise a text's vector depends on how much padding
        # the rest of its batch forced onto it.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_sum = (hidden * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1)
        pooled_batches.append(token_sum / token_count)

    embeddings = torch.cat(pooled_batches, dim=0).cpu().numpy()
    return embeddings
# Function to process the uploaded file and summarize by country
def process_and_summarize(uploaded_file, top_n=50):
    """Read an uploaded poem file and summarize topics and emotions per country.

    A single BERTopic model is fitted on all poems, so the topic ids reported
    for each country refer to the same topics as the returned model.

    Args:
        uploaded_file: File-like object with a ``name`` attribute ending in
            ``.csv`` or ``.xlsx`` (case-insensitive) and containing the
            columns ``country`` and ``poem``.
        top_n: Maximum number of topics / emotions kept per country.

    Returns:
        ``(summaries, topic_model)``. ``summaries`` is a list of dicts with
        the keys ``country``, ``total_poems``, ``top_topics`` and
        ``top_emotions``; the last two are lists of ``(value, count)`` pairs.
        ``(None, None)`` is returned, after reporting via ``st.error``, when
        the file type is unsupported, a required column is missing, or no
        usable rows remain.
    """
    # Determine the file type (extension compared case-insensitively)
    filename = uploaded_file.name.lower()
    if filename.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    elif filename.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format.")
        return None, None

    # Validate required columns
    required_columns = ['country', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None, None

    # Drop missing values first, then coerce to str so that the `.str`
    # accessor also works when a column was parsed as a non-string dtype.
    df = df.dropna(subset=required_columns).copy()
    df['country'] = df['country'].astype(str).str.strip()
    df['poem'] = df['poem'].astype(str).str.strip()
    df = df[(df['country'] != '') & (df['poem'] != '')]
    if df.empty:
        st.error("No poems found in the uploaded file.")
        return None, None

    # Per-country pass: classify emotions and compute embeddings
    per_country = []
    for country, group in df.groupby('country'):
        st.info(f"Processing poems for {country}...")
        texts = group['poem'].tolist()

        st.info(f"Classifying emotions for {country}...")
        # Truncate to the model's 512-token limit; without it the pipeline
        # fails on long poems.
        emotions = [
            emotion_classifier(text, truncation=True, max_length=512)[0]['label']
            for text in texts
        ]

        st.info(f"Generating embeddings and topics for {country}...")
        embeddings = generate_embeddings(texts)
        per_country.append((country, texts, emotions, embeddings))

    # Fit one topic model on every poem. BERTopic.fit_transform takes the
    # documents first and the precomputed embeddings second.
    st.info("Fitting topic model on all poems...")
    all_texts = [text for _, texts, _, _ in per_country for text in texts]
    all_embeddings = np.vstack([embeddings for _, _, _, embeddings in per_country])
    topic_model = BERTopic()
    topics, _ = topic_model.fit_transform(all_texts, all_embeddings)

    # Aggregate topics and emotions; `topics` is ordered like `all_texts`,
    # so each country owns a contiguous slice.
    summaries = []
    offset = 0
    for country, texts, emotions, _ in per_country:
        country_topics = topics[offset:offset + len(texts)]
        offset += len(texts)
        summaries.append({
            'country': country,
            'total_poems': len(texts),
            'top_topics': Counter(country_topics).most_common(top_n),
            'top_emotions': Counter(emotions).most_common(top_n),
        })
    return summaries, topic_model
# Streamlit App Interface
def _show_country_summary(entry, top_n):
    """Write one country's poem count, top topics and top emotions to the page."""
    st.write(f"### {entry['country']}")
    st.write(f"Total Poems: {entry['total_poems']}")
    st.write(f"Top {top_n} Topics:")
    st.write(entry['top_topics'])
    st.write(f"Top {top_n} Emotions:")
    st.write(entry['top_emotions'])


st.title("Arabic Poem Topic Modeling & Emotion Classification")
st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])

if uploaded_file is not None:
    try:
        top_n = st.number_input(
            "Select the number of top topics/emotions to display:",
            min_value=1,
            max_value=100,
            value=50,
        )
        summaries, topic_model = process_and_summarize(uploaded_file, top_n=top_n)
        # A None result means the processing step already reported an error.
        if summaries is not None:
            st.success("Data successfully processed!")

            # One section per country
            for country_summary in summaries:
                _show_country_summary(country_summary, top_n)

            # Topic table from the returned topic model
            st.write("### Global Topic Information:")
            topic_overview = topic_model.get_topic_info()
            st.write(topic_overview)
    except Exception as e:
        # Top-level boundary of the app: surface any failure on the page.
        st.error(f"Error: {e}")