kambris committed on
Commit 4ec5d16 · verified · 1 Parent(s): 6bd6b44

Update app.py

Files changed (1)
  1. app.py +89 -81
app.py CHANGED
@@ -11,94 +11,87 @@ bert_tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv2")
  bert_model = AutoModel.from_pretrained("aubmindlab/bert-base-arabertv2")

  # Load AraBERT model for emotion classification
- emotion_model = AutoModelForSequenceClassification.from_pretrained("aubmindlab/bert-base-arabertv2")
+ emotion_model = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sentiment")
  emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=bert_tokenizer)

- # Function to generate embeddings using AraBERT
- def generate_embeddings(texts):
-     all_embeddings = []
-
-     for text in texts:
-         # Tokenize with truncation to handle long sequences
-         inputs = bert_tokenizer(
-             text,
-             return_tensors="pt",
-             padding=True,
-             truncation=True,
-             max_length=512
-         )
-
-         # Generate embeddings
-         with torch.no_grad():
-             outputs = bert_model(**inputs)
-
-         # Get the mean of the last hidden state as the embedding
-         embedding = outputs.last_hidden_state.mean(dim=1).numpy()
-         all_embeddings.append(embedding[0])  # Remove batch dimension
-
-     return np.array(all_embeddings)
-
- # Function to perform emotion classification with proper truncation
- def classify_emotions(texts):
-     emotions = []
-     for text in texts:
-         # Process text in chunks if it's too long
-         if len(bert_tokenizer.encode(text)) > 512:
-             chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
-             # Take the emotion of the first chunk (usually contains the most relevant information)
-             emotion = emotion_classifier(chunks[0])[0]['label']
+ # Define emotion labels mapping
+ EMOTION_LABELS = {
+     'LABEL_0': 'Negative',
+     'LABEL_1': 'Positive',
+     'LABEL_2': 'Neutral'
+ }
+
+ def format_topics(topic_model, topic_counts):
+     """Convert topic numbers to readable labels."""
+     formatted_topics = []
+     for topic_num, count in topic_counts:
+         if topic_num == -1:
+             topic_label = "Miscellaneous"
          else:
-             emotion = emotion_classifier(text)[0]['label']
-         emotions.append(emotion)
-     return emotions
+             # Get the top words for this topic
+             words = topic_model.get_topic(topic_num)
+             # Take the top 3 words to form a topic label
+             topic_label = " | ".join([word for word, _ in words[:3]])
+
+         formatted_topics.append({
+             'topic': topic_label,
+             'count': count
+         })
+     return formatted_topics
+
+ def format_emotions(emotion_counts):
+     """Convert emotion labels to readable text."""
+     formatted_emotions = []
+     for label, count in emotion_counts:
+         emotion = EMOTION_LABELS.get(label, label)
+         formatted_emotions.append({
+             'emotion': emotion,
+             'count': count
+         })
+     return formatted_emotions
+
+ # [Previous functions remain the same until process_and_summarize]

- # Function to process the uploaded file and summarize by country
  def process_and_summarize(uploaded_file, top_n=50):
-     # Determine the file type
-     if uploaded_file.name.endswith(".csv"):
-         df = pd.read_csv(uploaded_file)
-     elif uploaded_file.name.endswith(".xlsx"):
-         df = pd.read_excel(uploaded_file)
-     else:
-         st.error("Unsupported file format.")
-         return None, None
-
-     # Validate required columns
-     required_columns = ['country', 'poem']
-     missing_columns = [col for col in required_columns if col not in df.columns]
-     if missing_columns:
-         st.error(f"Missing columns: {', '.join(missing_columns)}")
-         return None, None
-
-     # Parse and preprocess the file
-     df['country'] = df['country'].str.strip()
-     df = df.dropna(subset=['country', 'poem'])
-
-     # Initialize BERTopic
-     topic_model = BERTopic(language="arabic")
+     # [Previous code remains the same until the summaries loop]
+
+     # Initialize BERTopic with specific parameters
+     topic_model = BERTopic(
+         language="arabic",
+         calculate_probabilities=True,
+         verbose=True
+     )

      # Group by country
      summaries = []
      for country, group in df.groupby('country'):
          st.info(f"Processing poems for {country}...")

-         # Get texts for this country
          texts = group['poem'].dropna().tolist()
-
-         # Classify emotions
-         st.info(f"Classifying emotions for {country}...")
-         emotions = classify_emotions(texts)
-
-         # Generate embeddings and fit topic model
-         st.info(f"Generating embeddings and topics for {country}...")
-         embeddings = generate_embeddings(texts)
+         batch_size = 10
+         all_emotions = []
+         all_embeddings = []

+         for i in range(0, len(texts), batch_size):
+             batch_texts = texts[i:i + batch_size]
+
+             st.info(f"Generating embeddings for batch {i//batch_size + 1}...")
+             batch_embeddings = generate_embeddings(batch_texts)
+             all_embeddings.extend(batch_embeddings)
+
+             st.info(f"Classifying emotions for batch {i//batch_size + 1}...")
+             batch_emotions = [classify_emotion(text) for text in batch_texts]
+             all_emotions.extend(batch_emotions)
+
          try:
+             embeddings = np.array(all_embeddings)
+
+             st.info(f"Fitting topic model for {country}...")
              topics, _ = topic_model.fit_transform(texts, embeddings)

-             # Aggregate topics and emotions
-             top_topics = Counter(topics).most_common(top_n)
-             top_emotions = Counter(emotions).most_common(top_n)
+             # Format topics and emotions with readable labels
+             top_topics = format_topics(topic_model, Counter(topics).most_common(top_n))
+             top_emotions = format_emotions(Counter(all_emotions).most_common(top_n))

              summaries.append({
                  'country': country,
@@ -120,7 +113,8 @@ uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])

  if uploaded_file is not None:
      try:
-         top_n = st.number_input("Select the number of top topics/emotions to display:", min_value=1, max_value=100, value=50)
+         top_n = st.number_input("Select the number of top topics/emotions to display:",
+                                 min_value=1, max_value=100, value=10)

          summaries, topic_model = process_and_summarize(uploaded_file, top_n=top_n)
          if summaries is not None:
@@ -130,13 +124,27 @@ if uploaded_file is not None:
              for summary in summaries:
                  st.write(f"### {summary['country']}")
                  st.write(f"Total Poems: {summary['total_poems']}")
-                 st.write(f"Top {top_n} Topics:")
-                 st.write(summary['top_topics'])
-                 st.write(f"Top {top_n} Emotions:")
-                 st.write(summary['top_emotions'])
-
-             # Display overall topics
+
+                 st.write(f"\nTop {top_n} Topics:")
+                 for topic in summary['top_topics']:
+                     st.write(f"• {topic['topic']}: {topic['count']} poems")
+
+                 st.write(f"\nTop {top_n} Emotions:")
+                 for emotion in summary['top_emotions']:
+                     st.write(f"• {emotion['emotion']}: {emotion['count']} poems")
+
+                 st.write("---")
+
+             # Display overall topics in a more readable format
              st.write("### Global Topic Information:")
-             st.write(topic_model.get_topic_info())
+             topic_info = topic_model.get_topic_info()
+             for _, row in topic_info.iterrows():
+                 if row['Topic'] == -1:
+                     topic_name = "Miscellaneous"
+                 else:
+                     words = topic_model.get_topic(row['Topic'])
+                     topic_name = " | ".join([word for word, _ in words[:3]])
+                 st.write(f"• Topic {row['Topic']}: {topic_name} ({row['Count']} poems)")
+
      except Exception as e:
-         st.error(f"Error: {e}")
+         st.error(f"Error: {str(e)}")