kambris committed on
Commit
0eea166
·
verified ·
1 Parent(s): ecc1b19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -33
app.py CHANGED
@@ -234,45 +234,31 @@ def classify_emotion(text, classifier):
234
  return final_emotion
235
 
236
def get_embedding_for_text(text, tokenizer, model):
    """Embed an arbitrarily long text by averaging per-chunk [CLS] vectors.

    The text is tokenized once, split into windows of at most 510 tokens
    (leaving room for the [CLS]/[SEP] markers), and each window is padded
    to 512 tokens and run through the model.  The [CLS] embeddings of all
    windows are averaged into a single vector; a zero vector of the model's
    hidden size is returned when the text yields no tokens at all.
    """
    all_tokens = tokenizer.tokenize(text)
    window = 510  # 512 minus the two special tokens
    vectors = []

    for start in range(0, len(all_tokens), window):
        # Wrap the window in the model's special tokens.
        piece = ['[CLS]'] + all_tokens[start:start + window] + ['[SEP]']
        ids = tokenizer.convert_tokens_to_ids(piece)

        # Pad to a fixed length of 512; the mask marks real tokens only.
        pad_len = 512 - len(ids)
        mask = [1] * len(piece) + [0] * pad_len
        ids = ids + [tokenizer.pad_token_id] * pad_len

        ids_t = torch.tensor([ids])
        mask_t = torch.tensor([mask])

        with torch.no_grad():
            out = model(ids_t, attention_mask=mask_t)
            # [CLS] sits at position 0 of the sequence dimension.
            vectors.append(out[0][:, 0, :].cpu().numpy()[0])

    if not vectors:
        # Nothing to embed (e.g. empty input) -> zero vector of model width.
        return np.zeros(model.config.hidden_size)
    return np.mean(vectors, axis=0)
277
 
278
  def format_topics(topic_model, topic_counts):
 
234
  return final_emotion
235
 
236
def get_embedding_for_text(text, tokenizer, model):
    """Get a single embedding vector for a complete text.

    The text is split into chunks via the module-level ``split_text`` helper,
    each chunk is encoded (truncated to 512 tokens) and run through the
    model, and the per-chunk [CLS] embeddings are combined into one vector
    weighted by each chunk's whitespace-token count, so longer chunks
    contribute proportionally more.

    Returns a 1-D numpy array of size ``model.config.hidden_size``; a zero
    vector when no chunks are produced.
    """
    chunks = split_text(text)

    chunk_embeddings = []
    for chunk in chunks:
        inputs = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        # Move every input tensor onto the model's device.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            # [CLS] token embedding: first position of the sequence.
            embedding = outputs[0][:, 0, :].cpu().numpy()
            chunk_embeddings.append(embedding[0])

    if chunk_embeddings:
        # Weight each chunk by its word count.
        weights = np.array([len(chunk.split()) for chunk in chunks], dtype=float)
        total = weights.sum()
        if total == 0:
            # All chunks are whitespace-only: avoid division by zero /
            # np.average's ZeroDivisionError and fall back to a plain mean.
            return np.mean(chunk_embeddings, axis=0)
        return np.average(chunk_embeddings, axis=0, weights=weights / total)

    # Fallback when no chunks were produced at all.
    return np.zeros(model.config.hidden_size)
263
 
264
  def format_topics(topic_model, topic_counts):