kambris committed on
Commit
b2576ed
·
verified ·
1 Parent(s): 0156b72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -76
app.py CHANGED
@@ -5,14 +5,31 @@ from bertopic import BERTopic
5
  import torch
6
  import numpy as np
7
  from collections import Counter
 
8
 
9
- # Load AraBERT tokenizer and model for embeddings
10
- bert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
11
- bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
 
 
 
12
 
13
- # Load AraBERT model for emotion classification
14
- emotion_model = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
15
- emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=bert_tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Define emotion labels mapping
18
  EMOTION_LABELS = {
@@ -21,80 +38,67 @@ EMOTION_LABELS = {
21
  'LABEL_2': 'Neutral'
22
  }
23
 
24
- def chunk_long_text(text, max_length=512):
25
- """
26
- Split text into chunks respecting AraBERT's token limit.
27
- Returns both tokenized chunks and decoded text chunks.
28
- """
29
- # Tokenize the entire text
30
- tokens = bert_tokenizer.encode(text, add_special_tokens=False)
31
  chunks = []
32
  text_chunks = []
33
 
34
- # Split into chunks of max_length-2 to account for [CLS] and [SEP]
35
  for i in range(0, len(tokens), max_length-2):
36
  chunk = tokens[i:i + max_length-2]
37
- # Add special tokens
38
- full_chunk = [bert_tokenizer.cls_token_id] + chunk + [bert_tokenizer.sep_token_id]
39
  chunks.append(full_chunk)
40
- # Decode the chunk back to text (without special tokens)
41
- text_chunks.append(bert_tokenizer.decode(chunk))
42
 
43
  return chunks, text_chunks
44
 
45
- def get_embedding_for_text(text):
46
- """Get embedding for a text, handling long sequences by averaging chunk embeddings."""
47
- _, text_chunks = chunk_long_text(text)
48
  chunk_embeddings = []
49
 
50
  for chunk in text_chunks:
51
- # Encode chunk with padding and attention mask
52
- inputs = bert_tokenizer(chunk,
53
- return_tensors="pt",
54
- padding=True,
55
- truncation=True,
56
- max_length=512)
57
- inputs = {k: v.to(bert_model.device) for k, v in inputs.items()}
58
 
59
  with torch.no_grad():
60
- outputs = bert_model(**inputs)
61
 
62
- # Get [CLS] token embedding
63
  embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
64
  chunk_embeddings.append(embedding[0])
65
 
66
- # Average embeddings from all chunks
67
  if chunk_embeddings:
68
  return np.mean(chunk_embeddings, axis=0)
69
- return np.zeros(bert_model.config.hidden_size)
70
 
71
- def generate_embeddings(texts):
72
  """Generate embeddings for a list of texts."""
73
  embeddings = []
74
 
75
  for text in texts:
76
  try:
77
- embedding = get_embedding_for_text(text)
78
  embeddings.append(embedding)
79
  except Exception as e:
80
  st.warning(f"Error processing text: {str(e)}")
81
- embeddings.append(np.zeros(bert_model.config.hidden_size))
82
 
83
  return np.array(embeddings)
84
 
85
- def classify_emotion(text):
86
- """
87
- Classify emotion for a text, handling long sequences by voting among chunks.
88
- """
89
  try:
90
- _, text_chunks = chunk_long_text(text)
91
  chunk_emotions = []
92
 
93
  for chunk in text_chunks:
94
- result = emotion_classifier(chunk, max_length=512, truncation=True)[0]
95
  chunk_emotions.append(result['label'])
96
 
97
- # Use majority voting for final emotion
98
  if chunk_emotions:
99
  final_emotion = Counter(chunk_emotions).most_common(1)[0][0]
100
  return final_emotion
@@ -105,7 +109,7 @@ def classify_emotion(text):
105
  return "unknown"
106
 
107
  def format_topics(topic_model, topic_counts):
108
- """Convert topic numbers to readable labels."""
109
  formatted_topics = []
110
  for topic_num, count in topic_counts:
111
  if topic_num == -1:
@@ -121,7 +125,7 @@ def format_topics(topic_model, topic_counts):
121
  return formatted_topics
122
 
123
  def format_emotions(emotion_counts):
124
- """Convert emotion labels to readable text."""
125
  formatted_emotions = []
126
  for label, count in emotion_counts:
127
  emotion = EMOTION_LABELS.get(label, label)
@@ -131,28 +135,11 @@ def format_emotions(emotion_counts):
131
  })
132
  return formatted_emotions
133
 
134
- def process_and_summarize(uploaded_file, top_n=50):
135
- # Determine the file type
136
- if uploaded_file.name.endswith(".csv"):
137
- df = pd.read_csv(uploaded_file)
138
- elif uploaded_file.name.endswith(".xlsx"):
139
- df = pd.read_excel(uploaded_file)
140
- else:
141
- st.error("Unsupported file format.")
142
- return None, None
143
-
144
- # Validate required columns
145
- required_columns = ['country', 'poem']
146
- missing_columns = [col for col in required_columns if col not in df.columns]
147
- if missing_columns:
148
- st.error(f"Missing columns: {', '.join(missing_columns)}")
149
- return None, None
150
-
151
- # Parse and preprocess the file
152
- df['country'] = df['country'].str.strip()
153
- df = df.dropna(subset=['country', 'poem'])
154
 
155
- # Initialize BERTopic with specific parameters for Arabic
156
  topic_model = BERTopic(
157
  language="arabic",
158
  calculate_probabilities=True,
@@ -161,27 +148,28 @@ def process_and_summarize(uploaded_file, top_n=50):
161
  )
162
 
163
  # Group by country
164
- summaries = []
165
  for country, group in df.groupby('country'):
166
- st.info(f"Processing poems for {country}...")
167
-
 
168
  texts = group['poem'].dropna().tolist()
169
  batch_size = 10
170
  all_emotions = []
171
 
172
- # Generate embeddings for all texts
173
- st.info("Generating embeddings...")
174
- embeddings = generate_embeddings(texts)
175
 
176
- # Process emotions in batches
177
  for i in range(0, len(texts), batch_size):
178
  batch_texts = texts[i:i + batch_size]
179
- st.info(f"Classifying emotions for batch {i//batch_size + 1}...")
180
- batch_emotions = [classify_emotion(text) for text in batch_texts]
181
  all_emotions.extend(batch_emotions)
 
182
 
183
  try:
184
- st.info(f"Fitting topic model for {country}...")
185
  topics, _ = topic_model.fit_transform(texts, embeddings)
186
 
187
  # Format results
@@ -194,10 +182,88 @@ def process_and_summarize(uploaded_file, top_n=50):
194
  'top_topics': top_topics,
195
  'top_emotions': top_emotions
196
  })
 
197
  except Exception as e:
198
  st.warning(f"Could not generate topics for {country}: {str(e)}")
199
  continue
200
 
201
  return summaries, topic_model
202
 
203
- # Streamlit interface remains the same...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch
6
  import numpy as np
7
  from collections import Counter
8
+ import os
9
 
10
+ # Configure page
11
+ st.set_page_config(
12
+ page_title="Arabic Poem Analysis",
13
+ page_icon="📚",
14
+ layout="wide"
15
+ )
16
 
17
+ @st.cache_resource
18
+ def load_models():
19
+ """Load and cache the models to prevent reloading"""
20
+ bert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
21
+ bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")
22
+ emotion_model = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
23
+ emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=bert_tokenizer)
24
+ return bert_tokenizer, bert_model, emotion_classifier
25
+
26
+ # Load models
27
+ try:
28
+ bert_tokenizer, bert_model, emotion_classifier = load_models()
29
+ st.success("Models loaded successfully!")
30
+ except Exception as e:
31
+ st.error(f"Error loading models: {str(e)}")
32
+ st.stop()
33
 
34
  # Define emotion labels mapping
35
  EMOTION_LABELS = {
 
38
  'LABEL_2': 'Neutral'
39
  }
40
 
41
+ def chunk_long_text(text, tokenizer, max_length=512):
42
+ """Split text into chunks respecting token limit."""
43
+ tokens = tokenizer.encode(text, add_special_tokens=False)
 
 
 
 
44
  chunks = []
45
  text_chunks = []
46
 
 
47
  for i in range(0, len(tokens), max_length-2):
48
  chunk = tokens[i:i + max_length-2]
49
+ full_chunk = [tokenizer.cls_token_id] + chunk + [tokenizer.sep_token_id]
 
50
  chunks.append(full_chunk)
51
+ text_chunks.append(tokenizer.decode(chunk))
 
52
 
53
  return chunks, text_chunks
54
 
55
+ def get_embedding_for_text(text, tokenizer, model):
56
+ """Get embedding for a text, handling long sequences."""
57
+ _, text_chunks = chunk_long_text(text, tokenizer)
58
  chunk_embeddings = []
59
 
60
  for chunk in text_chunks:
61
+ inputs = tokenizer(chunk,
62
+ return_tensors="pt",
63
+ padding=True,
64
+ truncation=True,
65
+ max_length=512)
66
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
67
 
68
  with torch.no_grad():
69
+ outputs = model(**inputs)
70
 
 
71
  embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
72
  chunk_embeddings.append(embedding[0])
73
 
 
74
  if chunk_embeddings:
75
  return np.mean(chunk_embeddings, axis=0)
76
+ return np.zeros(model.config.hidden_size)
77
 
78
+ def generate_embeddings(texts, tokenizer, model):
79
  """Generate embeddings for a list of texts."""
80
  embeddings = []
81
 
82
  for text in texts:
83
  try:
84
+ embedding = get_embedding_for_text(text, tokenizer, model)
85
  embeddings.append(embedding)
86
  except Exception as e:
87
  st.warning(f"Error processing text: {str(e)}")
88
+ embeddings.append(np.zeros(model.config.hidden_size))
89
 
90
  return np.array(embeddings)
91
 
92
+ def classify_emotion(text, tokenizer, classifier):
93
+ """Classify emotion for a text using majority voting."""
 
 
94
  try:
95
+ _, text_chunks = chunk_long_text(text, tokenizer)
96
  chunk_emotions = []
97
 
98
  for chunk in text_chunks:
99
+ result = classifier(chunk, max_length=512, truncation=True)[0]
100
  chunk_emotions.append(result['label'])
101
 
 
102
  if chunk_emotions:
103
  final_emotion = Counter(chunk_emotions).most_common(1)[0][0]
104
  return final_emotion
 
109
  return "unknown"
110
 
111
  def format_topics(topic_model, topic_counts):
112
+ """Format topics for display."""
113
  formatted_topics = []
114
  for topic_num, count in topic_counts:
115
  if topic_num == -1:
 
125
  return formatted_topics
126
 
127
  def format_emotions(emotion_counts):
128
+ """Format emotions for display."""
129
  formatted_emotions = []
130
  for label, count in emotion_counts:
131
  emotion = EMOTION_LABELS.get(label, label)
 
135
  })
136
  return formatted_emotions
137
 
138
+ def process_and_summarize(df, top_n=50):
139
+ """Process the data and generate summaries."""
140
+ summaries = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ # Initialize BERTopic
143
  topic_model = BERTopic(
144
  language="arabic",
145
  calculate_probabilities=True,
 
148
  )
149
 
150
  # Group by country
 
151
  for country, group in df.groupby('country'):
152
+ progress_text = f"Processing poems for {country}..."
153
+ progress_bar = st.progress(0, text=progress_text)
154
+
155
  texts = group['poem'].dropna().tolist()
156
  batch_size = 10
157
  all_emotions = []
158
 
159
+ # Generate embeddings
160
+ embeddings = generate_embeddings(texts, bert_tokenizer, bert_model)
161
+ progress_bar.progress(0.33, text="Generating embeddings...")
162
 
163
+ # Process emotions
164
  for i in range(0, len(texts), batch_size):
165
  batch_texts = texts[i:i + batch_size]
166
+ batch_emotions = [classify_emotion(text, bert_tokenizer, emotion_classifier)
167
+ for text in batch_texts]
168
  all_emotions.extend(batch_emotions)
169
+ progress_bar.progress(0.66, text="Classifying emotions...")
170
 
171
  try:
172
+ # Fit topic model
173
  topics, _ = topic_model.fit_transform(texts, embeddings)
174
 
175
  # Format results
 
182
  'top_topics': top_topics,
183
  'top_emotions': top_emotions
184
  })
185
+ progress_bar.progress(1.0, text="Processing complete!")
186
  except Exception as e:
187
  st.warning(f"Could not generate topics for {country}: {str(e)}")
188
  continue
189
 
190
  return summaries, topic_model
191
 
192
+ # Main app interface
193
+ st.title("📚 Arabic Poem Analysis")
194
+ st.write("Upload a CSV or Excel file containing Arabic poems with columns `country` and `poem`.")
195
+
196
+ # File upload
197
+ uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
198
+
199
+ if uploaded_file is not None:
200
+ try:
201
+ # Read the file
202
+ if uploaded_file.name.endswith('.csv'):
203
+ df = pd.read_csv(uploaded_file)
204
+ else:
205
+ df = pd.read_excel(uploaded_file)
206
+
207
+ # Validate columns
208
+ required_columns = ['country', 'poem']
209
+ if not all(col in df.columns for col in required_columns):
210
+ st.error("File must contain 'country' and 'poem' columns.")
211
+ st.stop()
212
+
213
+ # Clean data
214
+ df['country'] = df['country'].str.strip()
215
+ df = df.dropna(subset=['country', 'poem'])
216
+
217
+ # Process data
218
+ top_n = st.number_input("Number of top topics/emotions to display:",
219
+ min_value=1, max_value=100, value=10)
220
+
221
+ if st.button("Process Data"):
222
+ with st.spinner("Processing your data..."):
223
+ summaries, topic_model = process_and_summarize(df, top_n=top_n)
224
+
225
+ if summaries:
226
+ st.success("Analysis complete!")
227
+
228
+ # Display results in tabs
229
+ tab1, tab2 = st.tabs(["Country Summaries", "Global Topics"])
230
+
231
+ with tab1:
232
+ for summary in summaries:
233
+ with st.expander(f"📍 {summary['country']} ({summary['total_poems']} poems)"):
234
+ col1, col2 = st.columns(2)
235
+
236
+ with col1:
237
+ st.subheader("Top Topics")
238
+ for topic in summary['top_topics']:
239
+ st.write(f"• {topic['topic']}: {topic['count']} poems")
240
+
241
+ with col2:
242
+ st.subheader("Emotions")
243
+ for emotion in summary['top_emotions']:
244
+ st.write(f"• {emotion['emotion']}: {emotion['count']} poems")
245
+
246
+ with tab2:
247
+ st.subheader("Global Topic Distribution")
248
+ topic_info = topic_model.get_topic_info()
249
+ for _, row in topic_info.iterrows():
250
+ if row['Topic'] == -1:
251
+ topic_name = "Miscellaneous"
252
+ else:
253
+ words = topic_model.get_topic(row['Topic'])
254
+ topic_name = " | ".join([word for word, _ in words[:3]])
255
+ st.write(f"• Topic {row['Topic']}: {topic_name} ({row['Count']} poems)")
256
+
257
+ except Exception as e:
258
+ st.error(f"Error processing file: {str(e)}")
259
+ else:
260
+ st.info("👆 Upload a file to get started!")
261
+
262
+ # Example format
263
+ st.write("### Expected File Format:")
264
+ example_df = pd.DataFrame({
265
+ 'country': ['Egypt', 'Saudi Arabia'],
266
+ 'poem': ['قصيدة مصرية', 'قصيدة سعودية']
267
+ })
268
+ st.dataframe(example_df)
269
+