# SoLProject / app.py
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from bertopic import BERTopic
import torch
import numpy as np
from collections import Counter
# Load AraBERT tokenizer and model for embeddings
bert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
# Load CAMeLBERT sentiment model for emotion classification, paired with its own tokenizer
# (the AraBERT tokenizer uses a different vocabulary and would mismatch this model)
emotion_tokenizer = AutoTokenizer.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
emotion_model = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer)
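# Note: Streamlit reruns this script on every interaction, so these module-level model
# loads are repeated each time. One possible refinement (an assumption, not part of the
# original code) is to wrap the loading in a helper decorated with @st.cache_resource,
# e.g.:
#
#     @st.cache_resource
#     def load_models():
#         ...  # load the tokenizers/models as above and return them
#
# and call it once so the models stay cached across reruns.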
# Define emotion labels mapping
EMOTION_LABELS = {
    'LABEL_0': 'Negative',
    'LABEL_1': 'Positive',
    'LABEL_2': 'Neutral'
}
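# The CAMeLBERT sentiment pipeline may already emit human-readable labels such as
# 'positive' / 'negative' / 'neutral' via the model's id2label config; format_emotions()
# below falls back to the raw label whenever it is not found in this mapping.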
def chunk_long_text(text, max_length=512):
"""
Split text into chunks respecting AraBERT's token limit.
Returns both tokenized chunks and decoded text chunks.
"""
# Tokenize the entire text
tokens = bert_tokenizer.encode(text, add_special_tokens=False)
chunks = []
text_chunks = []
# Split into chunks of max_length-2 to account for [CLS] and [SEP]
for i in range(0, len(tokens), max_length-2):
chunk = tokens[i:i + max_length-2]
# Add special tokens
full_chunk = [bert_tokenizer.cls_token_id] + chunk + [bert_tokenizer.sep_token_id]
chunks.append(full_chunk)
# Decode the chunk back to text (without special tokens)
text_chunks.append(bert_tokenizer.decode(chunk))
return chunks, text_chunks
def get_embedding_for_text(text):
"""Get embedding for a text, handling long sequences by averaging chunk embeddings."""
_, text_chunks = chunk_long_text(text)
chunk_embeddings = []
for chunk in text_chunks:
# Encode chunk with padding and attention mask
inputs = bert_tokenizer(chunk,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512)
inputs = {k: v.to(bert_model.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = bert_model(**inputs)
# Get [CLS] token embedding
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
chunk_embeddings.append(embedding[0])
# Average embeddings from all chunks
if chunk_embeddings:
return np.mean(chunk_embeddings, axis=0)
return np.zeros(bert_model.config.hidden_size)
def generate_embeddings(texts):
"""Generate embeddings for a list of texts."""
embeddings = []
for text in texts:
try:
embedding = get_embedding_for_text(text)
embeddings.append(embedding)
except Exception as e:
st.warning(f"Error processing text: {str(e)}")
embeddings.append(np.zeros(bert_model.config.hidden_size))
return np.array(embeddings)
def classify_emotion(text):
"""
Classify emotion for a text, handling long sequences by voting among chunks.
"""
try:
_, text_chunks = chunk_long_text(text)
chunk_emotions = []
for chunk in text_chunks:
result = emotion_classifier(chunk, max_length=512, truncation=True)[0]
chunk_emotions.append(result['label'])
# Use majority voting for final emotion
if chunk_emotions:
final_emotion = Counter(chunk_emotions).most_common(1)[0][0]
return final_emotion
return "unknown"
except Exception as e:
st.warning(f"Error in emotion classification: {str(e)}")
return "unknown"
def format_topics(topic_model, topic_counts):
"""Convert topic numbers to readable labels."""
formatted_topics = []
for topic_num, count in topic_counts:
if topic_num == -1:
topic_label = "Miscellaneous"
else:
words = topic_model.get_topic(topic_num)
topic_label = " | ".join([word for word, _ in words[:3]])
formatted_topics.append({
'topic': topic_label,
'count': count
})
return formatted_topics
def format_emotions(emotion_counts):
"""Convert emotion labels to readable text."""
formatted_emotions = []
for label, count in emotion_counts:
emotion = EMOTION_LABELS.get(label, label)
formatted_emotions.append({
'emotion': emotion,
'count': count
})
return formatted_emotions
def process_and_summarize(uploaded_file, top_n=50):
    """Read the uploaded file and build per-country topic and emotion summaries."""
    # Determine the file type
    if uploaded_file.name.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    elif uploaded_file.name.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format. Please upload a .csv or .xlsx file.")
        return None, None
    # Validate required columns
    required_columns = ['country', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None, None
    # Clean and preprocess the data
    df['country'] = df['country'].str.strip()
    df = df.dropna(subset=['country', 'poem'])
    # Initialize BERTopic with parameters suited to Arabic text
    topic_model = BERTopic(
        language="arabic",
        calculate_probabilities=True,
        min_topic_size=5,
        verbose=True
    )
    # Group poems by country and summarize each group
    summaries = []
    for country, group in df.groupby('country'):
        st.info(f"Processing poems for {country}...")
        texts = group['poem'].dropna().tolist()
        batch_size = 10
        all_emotions = []
        # Generate embeddings for all texts
        st.info("Generating embeddings...")
        embeddings = generate_embeddings(texts)
        # Classify emotions in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            st.info(f"Classifying emotions for batch {i // batch_size + 1}...")
            batch_emotions = [classify_emotion(text) for text in batch_texts]
            all_emotions.extend(batch_emotions)
        try:
            st.info(f"Fitting topic model for {country}...")
            topics, _ = topic_model.fit_transform(texts, embeddings)
            # Format results
            top_topics = format_topics(topic_model, Counter(topics).most_common(top_n))
            top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))
            summaries.append({
                'country': country,
                'total_poems': len(texts),
                'top_topics': top_topics,
                'top_emotions': top_emotions
            })
        except Exception as e:
            st.warning(f"Could not generate topics for {country}: {str(e)}")
            continue
    return summaries, topic_model
# Streamlit interface remains the same...
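# A minimal, illustrative sketch of that interface is given below, wired to
# process_and_summarize(); the widget labels and layout are assumptions for
# completeness, not the project's actual UI code.
def main():
    st.title("Arabic Poetry Topic and Emotion Analysis")
    uploaded_file = st.file_uploader(
        "Upload a CSV or Excel file with 'country' and 'poem' columns",
        type=["csv", "xlsx"]
    )
    top_n = st.slider("Number of top topics/emotions per country", 5, 100, 50)
    if uploaded_file is not None and st.button("Analyze"):
        summaries, _topic_model = process_and_summarize(uploaded_file, top_n=top_n)
        if summaries:
            for summary in summaries:
                st.subheader(f"{summary['country']} ({summary['total_poems']} poems)")
                st.write("Top topics:")
                st.dataframe(pd.DataFrame(summary['top_topics']))
                st.write("Top emotions:")
                st.dataframe(pd.DataFrame(summary['top_emotions']))


if __name__ == "__main__":
    main()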