kambris committed on
Commit
6bd6b44
·
verified ·
1 Parent(s): 1c5ddd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -36
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
4
  from bertopic import BERTopic
5
  import torch
 
6
  from collections import Counter
7
 
8
  # Load AraBERT tokenizer and model for embeddings
@@ -15,27 +16,41 @@ emotion_classifier = pipeline("text-classification", model=emotion_model, tokeni
15
 
16
  # Function to generate embeddings using AraBERT
17
  def generate_embeddings(texts):
18
- # Tokenize all the texts (poems)
19
- inputs = bert_tokenizer(texts, return_tensors="pt", padding=True, truncation=False, max_length=512)
20
-
21
- chunked_inputs = []
22
- for input_ids in inputs['input_ids']:
23
- # Split each long sequence into chunks of max 512 tokens
24
- chunks = [input_ids[i:i + 512] for i in range(0, len(input_ids), 512)]
25
- chunked_inputs.extend(chunks)
26
-
27
- # Process each chunk and get embeddings
28
- embeddings = []
29
- for chunk in chunked_inputs:
30
- input_tensor = torch.tensor(chunk).unsqueeze(0) # Add batch dimension
31
  with torch.no_grad():
32
- outputs = bert_model(input_tensor)
33
- chunk_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
34
- embeddings.append(chunk_embedding)
35
-
36
- # Combine all embeddings (you can take the average of embeddings for each poem)
37
- final_embeddings = sum(embeddings) / len(embeddings)
38
- return final_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # Function to process the uploaded file and summarize by country
41
  def process_and_summarize(uploaded_file, top_n=50):
@@ -59,34 +74,41 @@ def process_and_summarize(uploaded_file, top_n=50):
59
  df['country'] = df['country'].str.strip()
60
  df = df.dropna(subset=['country', 'poem'])
61
 
 
 
 
62
  # Group by country
63
  summaries = []
64
- topic_model = BERTopic()
65
  for country, group in df.groupby('country'):
66
  st.info(f"Processing poems for {country}...")
67
 
68
- # Combine all poems for the country
69
  texts = group['poem'].dropna().tolist()
70
 
71
  # Classify emotions
72
  st.info(f"Classifying emotions for {country}...")
73
- emotions = [emotion_classifier(text)[0]['label'] for text in texts]
74
 
75
  # Generate embeddings and fit topic model
76
  st.info(f"Generating embeddings and topics for {country}...")
77
  embeddings = generate_embeddings(texts)
78
- topics, _ = topic_model.fit_transform(embeddings)
79
-
80
- # Aggregate topics and emotions
81
- top_topics = Counter(topics).most_common(top_n)
82
- top_emotions = Counter(emotions).most_common(top_n)
83
-
84
- summaries.append({
85
- 'country': country,
86
- 'total_poems': len(texts),
87
- 'top_topics': top_topics,
88
- 'top_emotions': top_emotions
89
- })
 
 
 
 
 
90
 
91
  return summaries, topic_model
92
 
@@ -117,4 +139,4 @@ if uploaded_file is not None:
117
  st.write("### Global Topic Information:")
118
  st.write(topic_model.get_topic_info())
119
  except Exception as e:
120
- st.error(f"Error: {e}")
 
3
  from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
4
  from bertopic import BERTopic
5
  import torch
6
+ import numpy as np
7
  from collections import Counter
8
 
9
  # Load AraBERT tokenizer and model for embeddings
 
16
 
17
  # Function to generate embeddings using AraBERT
18
  def generate_embeddings(texts):
19
+ all_embeddings = []
20
+
21
+ for text in texts:
22
+ # Tokenize with truncation to handle long sequences
23
+ inputs = bert_tokenizer(
24
+ text,
25
+ return_tensors="pt",
26
+ padding=True,
27
+ truncation=True,
28
+ max_length=512
29
+ )
30
+
31
+ # Generate embeddings
32
  with torch.no_grad():
33
+ outputs = bert_model(**inputs)
34
+
35
+ # Get the mean of the last hidden state as the embedding
36
+ embedding = outputs.last_hidden_state.mean(dim=1).numpy()
37
+ all_embeddings.append(embedding[0]) # Remove batch dimension
38
+
39
+ return np.array(all_embeddings)
40
+
41
+ # Function to perform emotion classification with proper truncation
42
def classify_emotions(texts):
    """Return the top emotion label for each text.

    Truncation is delegated to the classification pipeline
    (``truncation=True, max_length=512``), so over-long texts are cut to
    512 *tokens* by the classifier's own tokenizer. The previous approach
    measured length with the embedding tokenizer and then sliced the text
    by 512 characters, which does not correspond to the model's token
    limit. As before, only the leading part of a long text is classified.

    Args:
        texts: Iterable of strings to classify.

    Returns:
        A list with one label string per input text, in input order.
    """
    return [
        emotion_classifier(text, truncation=True, max_length=512)[0]['label']
        for text in texts
    ]
54
 
55
  # Function to process the uploaded file and summarize by country
56
  def process_and_summarize(uploaded_file, top_n=50):
 
74
  df['country'] = df['country'].str.strip()
75
  df = df.dropna(subset=['country', 'poem'])
76
 
77
+ # Initialize BERTopic
78
+ topic_model = BERTopic(language="arabic")
79
+
80
  # Group by country
81
  summaries = []
 
82
  for country, group in df.groupby('country'):
83
  st.info(f"Processing poems for {country}...")
84
 
85
+ # Get texts for this country
86
  texts = group['poem'].dropna().tolist()
87
 
88
  # Classify emotions
89
  st.info(f"Classifying emotions for {country}...")
90
+ emotions = classify_emotions(texts)
91
 
92
  # Generate embeddings and fit topic model
93
  st.info(f"Generating embeddings and topics for {country}...")
94
  embeddings = generate_embeddings(texts)
95
+
96
+ try:
97
+ topics, _ = topic_model.fit_transform(texts, embeddings)
98
+
99
+ # Aggregate topics and emotions
100
+ top_topics = Counter(topics).most_common(top_n)
101
+ top_emotions = Counter(emotions).most_common(top_n)
102
+
103
+ summaries.append({
104
+ 'country': country,
105
+ 'total_poems': len(texts),
106
+ 'top_topics': top_topics,
107
+ 'top_emotions': top_emotions
108
+ })
109
+ except Exception as e:
110
+ st.warning(f"Could not generate topics for {country}: {str(e)}")
111
+ continue
112
 
113
  return summaries, topic_model
114
 
 
139
  st.write("### Global Topic Information:")
140
  st.write(topic_model.get_topic_info())
141
  except Exception as e:
142
+ st.error(f"Error: {e}")