SoLProject / app.py
import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from bertopic import BERTopic

# Initialize the AraBERT tokenizer and model used to embed the poems
# (the embeddings are then clustered by BERTopic for topic modeling)
bert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
bert_model.eval()  # inference only, no gradient updates needed
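
# Note: the AraBERT authors recommend normalizing input text with their
# ArabertPreprocessor before tokenization. A minimal sketch, assuming the
# optional `arabert` package is installed (not wired into the app below):
#   from arabert.preprocess import ArabertPreprocessor
#   arabert_prep = ArabertPreprocessor(model_name="aubmindlab/bert-base-arabertv2")
#   clean_text = arabert_prep.preprocess(raw_text)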

# Embed each poem with AraBERT: long poems are split into 512-token chunks
# and the chunk embeddings are averaged into one vector per poem
def generate_embeddings(texts):
    embeddings = []
    for text in texts:
        # Tokenize the full text without truncation so no content is lost
        tokens = bert_tokenizer.encode(text, truncation=False)
        # Split the token sequence into chunks of at most 512 tokens
        chunks = [tokens[i:i + 512] for i in range(0, len(tokens), 512)]
        poem_embeddings = []
        for chunk in chunks:
            # Decode the chunk back into text so the tokenizer can rebuild
            # the special tokens and attention mask for this window
            chunk_text = bert_tokenizer.decode(chunk, skip_special_tokens=True)
            inputs = bert_tokenizer(chunk_text, return_tensors="pt",
                                    padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                outputs = bert_model(**inputs)
            # Mean-pool the last hidden layer over the token dimension
            chunk_embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0).numpy()
            poem_embeddings.append(chunk_embedding)
        # Average the chunk embeddings (concatenation is an alternative,
        # but averaging keeps a fixed-size vector per poem)
        embeddings.append(np.mean(np.array(poem_embeddings), axis=0))
    # Stack into a single (num_poems, hidden_size) array for BERTopic
    return np.vstack(embeddings)
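
# Illustrative check (hypothetical call, not part of the app): bert-base
# models have a hidden size of 768, so embedding two poems should yield
# an array of shape (2, 768):
#   vecs = generate_embeddings(["poem one", "poem two"])
#   assert vecs.shape == (2, 768)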

# Read an uploaded CSV or Excel file, validate it, and assign a topic
# to each poem
def process_file(uploaded_file):
    # Determine the file type from the extension
    if uploaded_file.name.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    elif uploaded_file.name.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format.")
        return None

    # Validate required columns
    required_columns = ['country', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None

    # Drop rows with missing values in the required columns
    df = df.dropna(subset=['country', 'poem'])
    texts = df['poem'].tolist()

    # Generate embeddings for all poems
    embeddings = generate_embeddings(texts)

    # Perform topic modeling with BERTopic; the documents and their
    # precomputed embeddings are passed together
    topic_model = BERTopic()
    topics, _ = topic_model.fit_transform(texts, embeddings)
    df['topic'] = topics
    return df
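
# Supplying precomputed embeddings to fit_transform makes BERTopic skip its
# default sentence-transformer encoding step and cluster the given vectors
# directly (dimensionality reduction with UMAP, then HDBSCAN clustering).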

# Streamlit App
st.title("Arabic Poem Topic Modeling & Emotion Classification")

uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
    try:
        result_df = process_file(uploaded_file)
        if result_df is not None:
            st.write("Data successfully processed!")
            st.write(result_df.head())
    except Exception as e:
        st.error(f"Error: {e}")