# SoLProject / app.py
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from bertopic import BERTopic
import torch
from collections import Counter
# Load AraBERT tokenizer and model for embeddings
bert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")

# Load AraBERT model for emotion classification.
# NOTE: this is the base AraBERT checkpoint, so the classification head is randomly
# initialized; substitute a checkpoint fine-tuned for emotion labels to get meaningful output.
emotion_model = AutoModelForSequenceClassification.from_pretrained("aubmindlab/bert-base-arabertv2")
# truncation/max_length keep long poems within the model's 512-token limit
emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=bert_tokenizer, truncation=True, max_length=512)
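
# Optional sketch (assumption: a Streamlit version that provides st.cache_resource,
# i.e. >= 1.18, and a hypothetical load_models helper): wrapping model loading in a
# cached loader avoids re-initializing the weights on every Streamlit rerun.
# Commented out because the module-level loading above is what this app actually uses.
#
# @st.cache_resource
# def load_models():
#     tok = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
#     enc = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
#     clf_model = AutoModelForSequenceClassification.from_pretrained("aubmindlab/bert-base-arabertv2")
#     clf = pipeline("text-classification", model=clf_model, tokenizer=tok, truncation=True, max_length=512)
#     return tok, enc, clf
#
# bert_tokenizer, bert_model, emotion_classifier = load_models()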
# Function to generate embeddings using AraBERT
def generate_embeddings(texts):
    # Tokenize the whole batch at once, padding/truncating to 512 tokens
    inputs = bert_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    # Mean-pool the token embeddings to get one fixed-size vector per text
    embeddings = outputs.last_hidden_state.mean(dim=1).numpy()
    return embeddings
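
# Illustrative usage only (values are assumptions, not output from this app): AraBERT
# base has a hidden size of 768, so a batch of N texts yields an (N, 768) NumPy array.
#
#   vecs = generate_embeddings(["مثال أول", "مثال ثان"])
#   vecs.shape  # (2, 768)
#
# For large corpora, the single tokenizer call above holds every poem in memory at
# once; splitting `texts` into smaller batches would keep memory bounded.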
# Function to process the uploaded file and summarize by country
def process_and_summarize(uploaded_file, top_n=50):
    # Determine the file type
    if uploaded_file.name.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    elif uploaded_file.name.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format.")
        return None, None

    # Validate required columns
    required_columns = ['country', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None, None

    # Clean and preprocess the file
    df['country'] = df['country'].str.strip()
    df = df.dropna(subset=['country', 'poem'])

    # Group by country and summarize each group
    summaries = []
    # The same BERTopic instance is re-fitted for every country, so after the loop
    # it holds the topics of the last country processed.
    topic_model = BERTopic()
    for country, group in df.groupby('country'):
        st.info(f"Processing poems for {country}...")
        # Collect all poems for the country
        texts = group['poem'].dropna().tolist()

        # Classify emotions
        st.info(f"Classifying emotions for {country}...")
        emotions = [emotion_classifier(text)[0]['label'] for text in texts]

        # Generate embeddings and fit the topic model
        st.info(f"Generating embeddings and topics for {country}...")
        embeddings = generate_embeddings(texts)
        # BERTopic expects the raw documents first; the precomputed embeddings are
        # passed as the second argument. Countries with only a handful of poems may
        # still fail here, since BERTopic's UMAP step needs a minimum number of documents.
        topics, _ = topic_model.fit_transform(texts, embeddings)

        # Aggregate topics and emotions
        top_topics = Counter(topics).most_common(top_n)
        top_emotions = Counter(emotions).most_common(top_n)
        summaries.append({
            'country': country,
            'total_poems': len(texts),
            'top_topics': top_topics,
            'top_emotions': top_emotions
        })
    return summaries, topic_model
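
# Illustrative shape of one `summaries` entry (values are made up, not real output):
#
#   {'country': 'Egypt',
#    'total_poems': 120,
#    'top_topics': [(0, 45), (1, 30), (-1, 20)],   # (BERTopic topic id, poem count); -1 is the outlier topic
#    'top_emotions': [('LABEL_0', 80), ('LABEL_1', 40)]}  # labels come from the classifier's config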
# Streamlit App Interface
st.title("Arabic Poem Topic Modeling & Emotion Classification")
st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
    try:
        top_n = st.number_input("Select the number of top topics/emotions to display:", min_value=1, max_value=100, value=50)
        summaries, topic_model = process_and_summarize(uploaded_file, top_n=top_n)
        if summaries is not None:
            st.success("Data successfully processed!")
            # Display summary for each country
            for summary in summaries:
                st.write(f"### {summary['country']}")
                st.write(f"Total Poems: {summary['total_poems']}")
                st.write(f"Top {top_n} Topics:")
                st.write(summary['top_topics'])
                st.write(f"Top {top_n} Emotions:")
                st.write(summary['top_emotions'])
            # Display the fitted topic model's details (the model is re-fitted per
            # country, so this table reflects the last country processed)
            st.write("### Topic Information (last processed country):")
            st.write(topic_model.get_topic_info())
    except Exception as e:
        st.error(f"Error: {e}")