# SoLProject / app.py
# Source page metadata (Hugging Face file view): author kambris,
# commit "Update app.py" (5fce9bd, verified), 2.36 kB.
import streamlit as st
import pandas as pd
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline
from bertopic import BERTopic
import torch
# Initialize ARAT5 model and tokenizer for topic modeling.
# Streamlit re-executes this whole script on every user interaction, so the
# loaders are wrapped in st.cache_resource: the models are loaded once and the
# cached objects are reused on later reruns instead of being reloaded each time.
@st.cache_resource
def _load_arat5():
    """Load and return the AraT5 tokenizer and model as a ``(tokenizer, model)`` pair."""
    return (
        T5Tokenizer.from_pretrained("UBC-NLP/araT5-base"),
        T5ForConditionalGeneration.from_pretrained("UBC-NLP/araT5-base"),
    )


@st.cache_resource
def _load_emotion_classifier():
    """Build the text-classification pipeline used to label each poem."""
    # NOTE(review): "aubmindlab/bert-base-arabertv2" looks like a base pretrained
    # AraBERT checkpoint rather than one fine-tuned for emotion classification.
    # If so, the classification head is freshly initialised and the predicted
    # labels are not meaningful. Confirm, or swap in an emotion-fine-tuned
    # Arabic model.
    return pipeline("text-classification", model="aubmindlab/bert-base-arabertv2")


tokenizer, model = _load_arat5()
# Emotion classification pipeline
emotion_classifier = _load_emotion_classifier()
# Function to get embeddings from ARAT5 for topic modeling
def generate_embeddings(texts, batch_size=32):
    """Embed texts with the AraT5 encoder using attention-masked mean pooling.

    Args:
        texts: Sequence of strings to embed.
        batch_size: Number of texts encoded per forward pass. This bounds peak
            memory without changing the result, because padding is masked out.

    Returns:
        A NumPy array with one row per input text, in the same order as
        ``texts``. An empty input yields an array of shape
        ``(0, model.config.d_model)``.
    """
    texts = list(texts)
    if not texts:
        # The tokenizer cannot build a tensor batch from an empty list.
        return torch.empty(0, model.config.d_model).numpy()

    pooled_batches = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            # Pass the attention mask so real tokens do not attend to padding.
            hidden = model.encoder(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            ).last_hidden_state
            # Average only over real tokens. A plain mean over dim=1 would pull
            # shorter texts in a padded batch towards the pad-token representation.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            pooled_batches.append(summed / counts)
    return torch.cat(pooled_batches, dim=0).numpy()
# Function to process the CSV or Excel file
def process_file(uploaded_file):
    """Load an uploaded poem table and annotate it with year, emotion and topic.

    The file must be ``.csv`` or ``.xlsx`` (the extension is checked
    case-insensitively) and must contain ``date`` and ``poem`` columns. Rows
    whose date cannot be parsed, or whose poem is missing, are dropped before
    annotation.

    Args:
        uploaded_file: The uploaded file object. It needs a ``name`` attribute
            and must be readable by ``pd.read_csv`` / ``pd.read_excel``.

    Returns:
        A DataFrame with added ``year``, ``emotion`` and ``topic`` columns. Returns
        ``None`` after reporting the problem via ``st.error`` when the format is
        unsupported, required columns are missing, or no usable rows remain.
    """
    # Determine the file type
    file_name = uploaded_file.name.lower()
    if file_name.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    elif file_name.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format.")
        return None

    # Validate required columns
    required_columns = ['date', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None

    df['date'] = pd.to_datetime(df['date'], errors='coerce')
    # Drop rows lacking either field *before* annotating, so that each per-poem
    # result list below lines up one-to-one with the frame's rows. Dropping
    # missing poems only from the text list made it shorter than the frame, and
    # the column assignment then failed.
    df = df.dropna(subset=['date', 'poem']).reset_index(drop=True)
    if df.empty:
        st.error("No rows with both a valid date and a poem.")
        return None
    df['year'] = df['date'].dt.year

    texts = df['poem'].astype(str).tolist()
    # truncation=True keeps poems longer than the model's maximum input length
    # from raising inside the pipeline.
    df['emotion'] = [
        emotion_classifier(text, truncation=True)[0]['label'] for text in texts
    ]

    embeddings = generate_embeddings(texts)
    topic_model = BERTopic()
    # BERTopic.fit_transform takes the documents first. Precomputed vectors go in
    # the `embeddings` keyword; previously they were passed as the documents.
    topics, _ = topic_model.fit_transform(texts, embeddings=embeddings)
    df['topic'] = topics
    return df
# Streamlit App
st.title("Arabic Poem Topic Modeling & Emotion Classification")
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
    try:
        # Emotion classification, embedding and topic fitting run for every
        # poem and can take a while; show a spinner so the page does not look
        # frozen.
        with st.spinner("Processing file..."):
            result_df = process_file(uploaded_file)
        if result_df is not None:
            st.write("Data successfully processed!")
            st.write(result_df.head())
    except Exception as e:
        # Top-level boundary of the app: report the failure on the page, with
        # the traceback for debugging, instead of letting the script run crash.
        st.error(f"Error: {e}")
        st.exception(e)