kambris committed on
Commit
7684baa
·
verified ·
1 Parent(s): 5fce9bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -34
app.py CHANGED
@@ -1,26 +1,28 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline
4
  from bertopic import BERTopic
5
  import torch
 
6
 
7
- # Initialize ARAT5 model and tokenizer for topic modeling
8
- tokenizer = T5Tokenizer.from_pretrained("UBC-NLP/araT5-base")
9
- model = T5ForConditionalGeneration.from_pretrained("UBC-NLP/araT5-base")
10
 
11
- # Emotion classification pipeline
12
- emotion_classifier = pipeline("text-classification", model="aubmindlab/bert-base-arabertv2")
 
13
 
14
- # Function to get embeddings from ARAT5 for topic modeling
15
  def generate_embeddings(texts):
16
- inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
17
  with torch.no_grad():
18
- outputs = model.encoder(input_ids=inputs['input_ids'])
19
- embeddings = outputs[0].mean(dim=1).numpy()
20
  return embeddings
21
 
22
- # Function to process the CSV or Excel file
23
- def process_file(uploaded_file):
24
  # Determine the file type
25
  if uploaded_file.name.endswith(".csv"):
26
  df = pd.read_csv(uploaded_file)
@@ -28,39 +30,75 @@ def process_file(uploaded_file):
28
  df = pd.read_excel(uploaded_file)
29
  else:
30
  st.error("Unsupported file format.")
31
- return None
32
 
33
  # Validate required columns
34
- required_columns = ['date', 'poem']
35
  missing_columns = [col for col in required_columns if col not in df.columns]
36
  if missing_columns:
37
  st.error(f"Missing columns: {', '.join(missing_columns)}")
38
- return None
39
-
40
- # Process the file
41
- df['date'] = pd.to_datetime(df['date'], errors='coerce')
42
- df = df.dropna(subset=['date'])
43
- df['year'] = df['date'].dt.year
44
-
45
- texts = df['poem'].dropna().tolist()
46
- emotions = [emotion_classifier(text)[0]['label'] for text in texts]
47
- df['emotion'] = emotions
48
-
49
- embeddings = generate_embeddings(texts)
50
  topic_model = BERTopic()
51
- topics, _ = topic_model.fit_transform(embeddings)
52
- df['topic'] = topics
53
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # Streamlit App
 
 
 
 
 
 
 
 
 
56
  st.title("Arabic Poem Topic Modeling & Emotion Classification")
 
 
57
  uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
58
 
59
  if uploaded_file is not None:
60
  try:
61
- result_df = process_file(uploaded_file)
62
- if result_df is not None:
63
- st.write("Data successfully processed!")
64
- st.write(result_df.head())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  except Exception as e:
66
  st.error(f"Error: {e}")
 
1
  import streamlit as st
2
  import pandas as pd
3
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
4
  from bertopic import BERTopic
5
  import torch
6
+ from collections import Counter
7
 
8
# Hugging Face checkpoint shared by the embedding encoder and the classifier.
ARABERT_CHECKPOINT = "aubmindlab/bert-base-arabertv2"


@st.cache_resource
def _load_models():
    """Load the AraBERT tokenizer, encoder and emotion classifier once.

    Streamlit re-executes this script from the top on every widget
    interaction; caching the loaded objects avoids re-reading the model
    weights on each rerun.

    Returns:
        Tuple ``(tokenizer, encoder, classifier_model, classifier)`` where
        ``classifier`` is a ``text-classification`` pipeline built from
        ``classifier_model`` and ``tokenizer``.
    """
    # Tokenizer and encoder used for embeddings
    tokenizer = AutoTokenizer.from_pretrained(ARABERT_CHECKPOINT)
    encoder = AutoModel.from_pretrained(ARABERT_CHECKPOINT)

    # Model and pipeline used for emotion classification
    # NOTE(review): this looks like a base (not emotion-fine-tuned) checkpoint,
    # in which case the classification head is freshly initialised and the
    # predicted labels are not meaningful emotions — confirm, and swap in a
    # fine-tuned emotion model if so.
    classifier_model = AutoModelForSequenceClassification.from_pretrained(ARABERT_CHECKPOINT)
    classifier = pipeline("text-classification", model=classifier_model, tokenizer=tokenizer)
    return tokenizer, encoder, classifier_model, classifier


# Module-level names kept unchanged for the functions below.
bert_tokenizer, bert_model, emotion_model, emotion_classifier = _load_models()
15
 
16
# Function to generate embeddings using AraBERT
def generate_embeddings(texts, batch_size=16):
    """Return one mean-pooled AraBERT embedding per input text.

    Args:
        texts: Iterable of strings to embed.
        batch_size: Number of texts encoded per forward pass; bounds peak
            memory when many poems are uploaded.

    Returns:
        A 2-D NumPy array with one row per text, in input order.
    """
    texts = list(texts)
    if not texts:
        # The tokenizer rejects an empty batch; return an empty 2-D array.
        return torch.empty((0, bert_model.config.hidden_size)).numpy()

    pooled_batches = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            hidden = bert_model(**inputs).last_hidden_state

            # Average over real tokens only. A plain mean over dim=1 also
            # averages the padding positions, so a short text's embedding
            # would depend on the longest text in its batch.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1)
            pooled_batches.append((hidden * mask).sum(dim=1) / token_counts)

    return torch.cat(pooled_batches, dim=0).numpy()
23
 
24
# Function to process the uploaded file and summarize by country
def process_and_summarize(uploaded_file, top_n=50):
    """Summarise the topics and emotions of uploaded poems per country.

    Args:
        uploaded_file: Uploaded file object exposing ``.name`` and readable by
            pandas; must contain ``country`` and ``poem`` columns.
        top_n: Maximum number of topics / emotions kept in each summary.

    Returns:
        ``(summaries, topic_model)``. ``summaries`` is a list of dicts with
        keys ``country``, ``total_poems``, ``top_topics`` and ``top_emotions``
        (the last two are lists of ``(value, count)`` pairs, most common
        first). ``topic_model`` is the BERTopic model fitted on all poems.
        Returns ``(None, None)`` after reporting the problem with
        ``st.error`` when the input cannot be used.
    """
    # Determine the file type (case-insensitively, so "POEMS.CSV" is accepted)
    filename = uploaded_file.name.lower()
    if filename.endswith(".csv"):
        df = pd.read_csv(uploaded_file)
    # NOTE(review): the condition of the Excel branch is not visible in the
    # diff this function was taken from; ".xlsx" is assumed — confirm.
    elif filename.endswith(".xlsx"):
        df = pd.read_excel(uploaded_file)
    else:
        st.error("Unsupported file format.")
        return None, None

    # Validate required columns
    required_columns = ['country', 'poem']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(f"Missing columns: {', '.join(missing_columns)}")
        return None, None

    # Drop missing values before casting so NaN does not become the text "nan";
    # casting keeps .str usable when a column was parsed as numeric.
    df = df.dropna(subset=required_columns).copy()
    df['country'] = df['country'].astype(str).str.strip()
    df['poem'] = df['poem'].astype(str).str.strip()
    df = df[(df['country'] != '') & (df['poem'] != '')]
    if df.empty:
        st.error("No rows with both a country and a poem were found.")
        return None, None

    texts = df['poem'].tolist()

    # Classify emotions; truncate so poems longer than the model's 512-token
    # limit do not raise inside the pipeline.
    st.info(f"Classifying emotions for {len(texts)} poems...")
    df['emotion'] = [
        emotion_classifier(text, truncation=True, max_length=512)[0]['label']
        for text in texts
    ]

    # Fit a single topic model on every poem so topic ids are comparable
    # across countries and the returned model really is global. BERTopic takes
    # the documents first and the precomputed vectors via `embeddings=`.
    st.info("Generating embeddings and topics...")
    embeddings = generate_embeddings(texts)
    topic_model = BERTopic()
    topics, _ = topic_model.fit_transform(texts, embeddings=embeddings)
    df['topic'] = topics

    # Aggregate topics and emotions per country
    summaries = []
    for country, group in df.groupby('country'):
        summaries.append({
            'country': country,
            'total_poems': len(group),
            'top_topics': Counter(group['topic']).most_common(top_n),
            'top_emotions': Counter(group['emotion']).most_common(top_n),
        })

    return summaries, topic_model
76
+
77
# Streamlit App Interface: page header, file upload, then per-country results.
st.title("Arabic Poem Topic Modeling & Emotion Classification")
st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")

uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])

if uploaded_file is not None:
    try:
        # How many topics / emotions each country's summary should list.
        top_n = st.number_input(
            "Select the number of top topics/emotions to display:",
            min_value=1,
            max_value=100,
            value=50,
        )

        summaries, topic_model = process_and_summarize(uploaded_file, top_n=top_n)
        if summaries is not None:
            st.success("Data successfully processed!")

            # One section per country: heading, poem count, then both rankings.
            ranked_sections = (("Topics", 'top_topics'), ("Emotions", 'top_emotions'))
            for country_summary in summaries:
                st.write(f"### {country_summary['country']}")
                st.write(f"Total Poems: {country_summary['total_poems']}")
                for heading, key in ranked_sections:
                    st.write(f"Top {top_n} {heading}:")
                    st.write(country_summary[key])

            # Topic table from the model returned by process_and_summarize.
            st.write("### Global Topic Information:")
            topic_info = topic_model.get_topic_info()
            st.write(topic_info)
    except Exception as err:
        # Top-level boundary: show any processing failure in the page.
        st.error(f"Error: {err}")