Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -15,16 +15,12 @@ emotion_classifier = pipeline("text-classification", model=emotion_model, tokeni
|
|
15 |
|
16 |
# Function to generate embeddings using AraBERT
|
17 |
def generate_embeddings(texts):
|
18 |
-
#
|
19 |
-
if isinstance(texts, str): # If single string, convert to list
|
20 |
-
texts = [texts]
|
21 |
-
|
22 |
-
# Tokenize the list of strings (ensure all are strings)
|
23 |
inputs = bert_tokenizer(texts, return_tensors="pt", padding=True, truncation=False, max_length=512)
|
24 |
|
25 |
-
# Split large sequences into chunks of size 512
|
26 |
chunked_inputs = []
|
27 |
for input_ids in inputs['input_ids']:
|
|
|
28 |
chunks = [input_ids[i:i + 512] for i in range(0, len(input_ids), 512)]
|
29 |
chunked_inputs.extend(chunks)
|
30 |
|
@@ -37,8 +33,9 @@ def generate_embeddings(texts):
|
|
37 |
chunk_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
|
38 |
embeddings.append(chunk_embedding)
|
39 |
|
40 |
-
#
|
41 |
-
|
|
|
42 |
|
43 |
# Function to process the uploaded file and summarize by country
|
44 |
def process_and_summarize(uploaded_file, top_n=50):
|
|
|
15 |
|
16 |
# Function to generate embeddings using AraBERT
|
17 |
def generate_embeddings(texts):
|
18 |
+
# Tokenize all the texts (poems)
|
|
|
|
|
|
|
|
|
19 |
inputs = bert_tokenizer(texts, return_tensors="pt", padding=True, truncation=False, max_length=512)
|
20 |
|
|
|
21 |
chunked_inputs = []
|
22 |
for input_ids in inputs['input_ids']:
|
23 |
+
# Split each long sequence into chunks of max 512 tokens
|
24 |
chunks = [input_ids[i:i + 512] for i in range(0, len(input_ids), 512)]
|
25 |
chunked_inputs.extend(chunks)
|
26 |
|
|
|
33 |
chunk_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
|
34 |
embeddings.append(chunk_embedding)
|
35 |
|
36 |
+
# Combine all embeddings (you can take the average of embeddings for each poem)
|
37 |
+
final_embeddings = sum(embeddings) / len(embeddings)
|
38 |
+
return final_embeddings
|
39 |
|
40 |
# Function to process the uploaded file and summarize by country
|
41 |
def process_and_summarize(uploaded_file, top_n=50):
|