kambris committed on
Commit
1c5ddd8
·
verified ·
1 Parent(s): 8c80330

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -15,16 +15,12 @@ emotion_classifier = pipeline("text-classification", model=emotion_model, tokeni
15
 
16
  # Function to generate embeddings using AraBERT
17
  def generate_embeddings(texts):
18
- # Ensure texts is a list of strings
19
- if isinstance(texts, str): # If single string, convert to list
20
- texts = [texts]
21
-
22
- # Tokenize the list of strings (ensure all are strings)
23
  inputs = bert_tokenizer(texts, return_tensors="pt", padding=True, truncation=False, max_length=512)
24
 
25
- # Split large sequences into chunks of size 512
26
  chunked_inputs = []
27
  for input_ids in inputs['input_ids']:
 
28
  chunks = [input_ids[i:i + 512] for i in range(0, len(input_ids), 512)]
29
  chunked_inputs.extend(chunks)
30
 
@@ -37,8 +33,9 @@ def generate_embeddings(texts):
37
  chunk_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
38
  embeddings.append(chunk_embedding)
39
 
40
- # Return the embeddings averaged across chunks
41
- return embeddings
 
42
 
43
  # Function to process the uploaded file and summarize by country
44
  def process_and_summarize(uploaded_file, top_n=50):
 
15
 
16
  # Function to generate embeddings using AraBERT
17
  def generate_embeddings(texts):
18
+ # Tokenize all the texts (poems)
 
 
 
 
19
  inputs = bert_tokenizer(texts, return_tensors="pt", padding=True, truncation=False, max_length=512)
20
 
 
21
  chunked_inputs = []
22
  for input_ids in inputs['input_ids']:
23
+ # Split each long sequence into chunks of max 512 tokens
24
  chunks = [input_ids[i:i + 512] for i in range(0, len(input_ids), 512)]
25
  chunked_inputs.extend(chunks)
26
 
 
33
  chunk_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
34
  embeddings.append(chunk_embedding)
35
 
36
+ # Combine all embeddings (you can take the average of embeddings for each poem)
37
+ final_embeddings = sum(embeddings) / len(embeddings)
38
+ return final_embeddings
39
 
40
  # Function to process the uploaded file and summarize by country
41
  def process_and_summarize(uploaded_file, top_n=50):