SoLProject / app.py
kambris's picture
Update app.py
6f973fa verified
raw
history blame
3.31 kB
import streamlit as st
import pandas as pd
from transformers import T5Tokenizer, T5ForConditionalGeneration, BertTokenizer, BertModel
from bertopic import BERTopic
import torch
import numpy as np
# Load the pretrained models once per Streamlit server process. Streamlit
# re-executes this script from the top on every widget interaction, so
# st.cache_resource keeps the loaded objects instead of re-instantiating them
# from disk on each rerun.
@st.cache_resource
def _load_arat5():
    """Return the AraT5 (UBC-NLP/araT5-base) tokenizer and generation model."""
    return (
        T5Tokenizer.from_pretrained("UBC-NLP/araT5-base"),
        T5ForConditionalGeneration.from_pretrained("UBC-NLP/araT5-base"),
    )


@st.cache_resource
def _load_arabert():
    """Return the AraBERT (aubmindlab/bert-base-arabertv2) tokenizer and encoder."""
    return (
        BertTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2"),
        BertModel.from_pretrained("aubmindlab/bert-base-arabertv2"),
    )


# AraT5 tokenizer/model. NOTE(review): nothing in this file reads `tokenizer`
# or `model` (poem embeddings come from AraBERT); consider dropping this load.
tokenizer, model = _load_arat5()
# AraBERT tokenizer and encoder, used by generate_embeddings for feature extraction.
bert_tokenizer, bert_model = _load_arabert()
# Embed poems with AraBERT (bert_tokenizer / bert_model) for topic modeling.
def generate_embeddings(texts, max_length=512):
    """Return one AraBERT embedding per text, chunking texts that exceed the model window.

    Each text is tokenized with ``bert_tokenizer`` without truncation. Its token
    ids are split into pieces of at most ``max_length - 2`` ids, and every piece
    is wrapped in its own [CLS] ... [SEP] pair before being run through
    ``bert_model``. The last hidden state of each chunk is mean-pooled over the
    token axis, and the chunk vectors are averaged (unweighted) into the text's
    vector.

    Args:
        texts: iterable of strings to embed.
        max_length: maximum number of token ids per model call, special tokens
            included. The default 512 is BERT-base's maximum input length.
            Must be at least 3.

    Returns:
        list of numpy arrays, one per text in input order, each of shape
        (1, hidden_size).

    Raises:
        ValueError: if ``max_length`` is smaller than 3.
    """
    if max_length < 3:
        raise ValueError("max_length must be at least 3 ([CLS] + one token + [SEP])")
    # Two slots of every chunk are reserved for [CLS] and [SEP].
    piece_size = max_length - 2
    embeddings = []
    with torch.no_grad():
        for text in texts:
            # Encode without special tokens so they can be added to *every* chunk.
            # A chunk missing [CLS]/[SEP], or consisting of a lone [SEP], is not a
            # well-formed BERT input and would skew the averaged embedding.
            ids = bert_tokenizer.encode(text, add_special_tokens=False, truncation=False)
            # An empty text still yields one "[CLS] [SEP]" chunk rather than no
            # chunks (averaging zero chunks would produce NaN).
            pieces = [ids[i:i + piece_size] for i in range(0, len(ids), piece_size)] or [[]]
            chunk_vectors = []
            for piece in pieces:
                # Batch of one sequence: shape (1, len(piece) + 2).
                input_ids = torch.tensor(
                    [bert_tokenizer.build_inputs_with_special_tokens(piece)]
                )
                hidden = bert_model(input_ids).last_hidden_state
                # Mean over the token axis -> (1, hidden_size).
                chunk_vectors.append(hidden.mean(dim=1).numpy())
            embeddings.append(np.mean(np.array(chunk_vectors), axis=0))
    return embeddings
# Load an uploaded CSV/XLSX of poems, embed the poems and assign BERTopic topics.
def process_file(uploaded_file):
    """Read the uploaded table, validate it, and add a BERTopic 'topic' column.

    Args:
        uploaded_file: the object returned by ``st.file_uploader``; only its
            ``.name`` is inspected here, and it is passed to pandas' readers.

    Returns:
        The DataFrame restricted to rows that have both 'country' and 'poem',
        with an added 'topic' column holding BERTopic's topic id per row, or
        None after reporting the problem with ``st.error``.
    """
    # Choose the reader from the file extension, case-insensitively
    # (so "POEMS.CSV" is accepted too).
    file_name = uploaded_file.name.lower()
    if file_name.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    elif file_name.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format.")
        return None

    # Validate required columns
    required_columns = ['country', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None

    # Keep only rows that have both fields. Copy so that adding the 'topic'
    # column below modifies an independent frame.
    df = df.dropna(subset=required_columns).copy()
    if df.empty:
        st.error("No rows with both 'country' and 'poem' values were found.")
        return None
    # Cells may be parsed as numbers (e.g. from Excel); the tokenizer needs str.
    texts = df['poem'].astype(str).tolist()

    # Generate embeddings for all poems
    embeddings = generate_embeddings(texts)

    # BERTopic.fit_transform takes the documents as its first argument and
    # precomputed embeddings through the `embeddings` keyword as a 2-D
    # (n_docs, dim) array; np.vstack stacks the per-poem vectors into that matrix.
    # NOTE(review): BERTopic's default UMAP/HDBSCAN settings may fail on very
    # small uploads; confirm the minimum dataset size this app must support.
    topic_model = BERTopic()
    topics, _ = topic_model.fit_transform(texts, embeddings=np.vstack(embeddings))
    # fit_transform returns one topic per document, in `texts` order, which
    # matches the row order of `df`.
    df['topic'] = topics
    return df
# Streamlit App
import logging

# NOTE(review): the title mentions emotion classification, but nothing in this
# file performs it; confirm whether that feature is still planned.
st.title("Arabic Poem Topic Modeling & Emotion Classification")
uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
if uploaded_file is not None:
    try:
        # Embedding every poem and fitting BERTopic can take a while, so show a
        # progress indicator instead of a frozen page.
        with st.spinner("Embedding poems and fitting topics..."):
            result_df = process_file(uploaded_file)
        if result_df is not None:
            st.write("Data successfully processed!")
            st.write(result_df.head())
    except Exception as e:
        # Top-level boundary: show a short message in the UI, and keep the full
        # traceback in the server log so failures can be diagnosed.
        logging.getLogger(__name__).exception("Failed to process uploaded file")
        st.error(f"Error: {e}")