kambris committed (verified)
Commit: e5d60e7 · 1 parent: 0ba80af

Update app.py

Files changed (1):
  1. app.py +36 -21
app.py CHANGED
@@ -235,36 +235,51 @@ def classify_emotion(text, classifier):
 
 def get_embedding_for_text(text, tokenizer, model):
     """Get embedding for complete text while preserving all content."""
-    # Get exact token counts
-    tokenized_text = tokenizer.encode(text)
-    total_tokens = len(tokenized_text)
+    # Pre-tokenize to get exact chunks
+    encoded = tokenizer.encode_plus(
+        text,
+        add_special_tokens=True,
+        return_tensors="pt",
+        return_attention_mask=True,
+        return_token_type_ids=True
+    )
 
-    # Create precise chunks of 512 tokens
-    chunks = []
-    for i in range(0, total_tokens, 512):
-        chunk = tokenized_text[i:i + 512]
-        chunks.append(tokenizer.decode(chunk))
+    # Get total length
+    total_length = encoded['input_ids'].size(1)
 
+    # Process in chunks of 512 tokens
     chunk_embeddings = []
-    for chunk in chunks:
-        inputs = tokenizer(
-            chunk,
-            return_tensors="pt",
-            padding='max_length',
-            max_length=512
-        )
-        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    for i in range(0, total_length, 512):
+        # Extract chunk
+        chunk_dict = {
+            'input_ids': encoded['input_ids'][:, i:i + 512],
+            'attention_mask': encoded['attention_mask'][:, i:i + 512],
+            'token_type_ids': encoded['token_type_ids'][:, i:i + 512]
+        }
+
+        # Pad if necessary
+        if chunk_dict['input_ids'].size(1) < 512:
+            pad_length = 512 - chunk_dict['input_ids'].size(1)
+            for key in chunk_dict:
+                chunk_dict[key] = torch.nn.functional.pad(
+                    chunk_dict[key],
+                    (0, pad_length),
+                    'constant',
+                    0
+                )
+
+        # Move to device
+        chunk_dict = {k: v.to(model.device) for k, v in chunk_dict.items()}
 
+        # Get embeddings
         with torch.no_grad():
-            outputs = model(**inputs)[0]
+            outputs = model(**chunk_dict)[0]
             embedding = outputs[:, 0, :].cpu().numpy()
             chunk_embeddings.append(embedding[0])
 
     if chunk_embeddings:
-        weights = np.array([len(chunk.split()) for chunk in chunks])
-        weights = weights / weights.sum()
-        weighted_embedding = np.average(chunk_embeddings, axis=0, weights=weights)
-        return weighted_embedding
+        # Average the embeddings
+        return np.mean(chunk_embeddings, axis=0)
     return np.zeros(model.config.hidden_size)
 
 def format_topics(topic_model, topic_counts):
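
For context, a minimal usage sketch of the updated function follows. Everything in it is an assumption rather than part of the diff: that app.py is importable as a module, that it imports torch and numpy as np at the top (both are used in the diff), and that the model is a BERT-style encoder — the "bert-base-uncased" checkpoint name is illustrative only, not necessarily what app.py loads.

    # Minimal usage sketch; checkpoint name and import path are assumptions.
    from transformers import AutoModel, AutoTokenizer

    from app import get_embedding_for_text  # assumes app.py is importable

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")
    model.eval()  # inference only; the function itself runs under no_grad

    # Long enough to span several 512-token chunks.
    long_text = "word " * 2000
    embedding = get_embedding_for_text(long_text, tokenizer, model)
    print(embedding.shape)  # (hidden_size,), e.g. (768,) for bert-base

Two behavioral changes are visible in the diff: chunking now slices the encoded tensors directly instead of decoding each token slice back to text and re-tokenizing it (a round trip that can shift token boundaries and re-insert special tokens per chunk), and the word-count-weighted average of chunk embeddings is replaced by a plain mean. One caveat: slices after the first do not begin with [CLS], so outputs[:, 0, :] for those chunks reads an ordinary token position rather than a [CLS] embedding.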