kambris committed on
Commit
7ad0bec
·
verified ·
1 Parent(s): e5d60e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -36
app.py CHANGED
@@ -234,51 +234,42 @@ def classify_emotion(text, classifier):
234
  return final_emotion
235
 
236
def get_embedding_for_text(text, tokenizer, model):
    """Embed *text* by mean-pooling per-chunk first-token embeddings.

    The full text is tokenized once, split into consecutive 512-token
    chunks, and each chunk is run through *model*; the first-token
    embedding of every chunk is averaged so no content is discarded by
    truncation.

    Returns a 1-D numpy array of length ``model.config.hidden_size``
    (zeros when the text yields no tokens).
    """
    # Tokenize the whole text once so chunk boundaries are exact.
    encoded = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        return_tensors="pt",
        return_attention_mask=True,
        return_token_type_ids=True,
    )

    total_length = encoded['input_ids'].size(1)

    # BUG FIX: the original padded input_ids with a literal 0, which is
    # wrong for tokenizers whose pad token id is not 0 — use the real pad id.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    chunk_embeddings = []
    for start in range(0, total_length, 512):
        # Slice out one 512-token window from each encoded tensor.
        chunk_dict = {
            'input_ids': encoded['input_ids'][:, start:start + 512],
            'attention_mask': encoded['attention_mask'][:, start:start + 512],
            'token_type_ids': encoded['token_type_ids'][:, start:start + 512],
        }

        # Right-pad the final (short) chunk up to 512 tokens; the attention
        # mask is padded with 0 so the model ignores the filler positions.
        pad_length = 512 - chunk_dict['input_ids'].size(1)
        if pad_length > 0:
            for key, value in chunk_dict.items():
                fill = pad_id if key == 'input_ids' else 0
                chunk_dict[key] = torch.nn.functional.pad(
                    value, (0, pad_length), 'constant', fill
                )

        # Move tensors to wherever the model lives.
        chunk_dict = {k: v.to(model.device) for k, v in chunk_dict.items()}

        with torch.no_grad():
            outputs = model(**chunk_dict)[0]
        # First-token embedding. NOTE(review): only the first chunk starts
        # with [CLS]; for later chunks this is an approximation.
        embedding = outputs[:, 0, :].cpu().numpy()
        chunk_embeddings.append(embedding[0])

    if chunk_embeddings:
        # Average the per-chunk embeddings into one vector.
        return np.mean(chunk_embeddings, axis=0)
    return np.zeros(model.config.hidden_size)
284
 
 
234
  return final_emotion
235
 
236
def get_embedding_for_text(text, tokenizer, model):
    """Embed *text* by mean-pooling per-chunk first-token embeddings.

    The text is split on whitespace and words are greedily packed into
    chunks that fit the 512-token model limit; each chunk is encoded and
    run through *model*, and the first-token embeddings are averaged.

    Returns a 1-D numpy array of length ``model.config.hidden_size``
    (zeros when the text contains no words).
    """
    max_tokens = 512

    # BUG FIX (performance): the original re-tokenized the entire growing
    # chunk for every word appended, which is O(n^2) in text length.
    # Tokenize each word once (without special tokens) and keep a running
    # count instead; reserve 2 slots for the [CLS]/[SEP] specials.
    chunks = []
    current_words = []
    current_count = 2
    for word in text.split():
        n_tok = len(tokenizer.encode(word, add_special_tokens=False))
        if current_words and current_count + n_tok >= max_tokens:
            # Flush the current chunk and start a new one with this word.
            chunks.append(" ".join(current_words))
            current_words = [word]
            current_count = 2 + n_tok
        else:
            current_words.append(word)
            current_count += n_tok
    if current_words:
        chunks.append(" ".join(current_words))

    chunk_embeddings = []
    for chunk in chunks:
        # truncation is a safety net for a single word that exceeds the limit.
        inputs = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            max_length=max_tokens,
            truncation=True,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)[0]
        # First-token ([CLS]) embedding for this chunk.
        embedding = outputs[:, 0, :].cpu().numpy()
        chunk_embeddings.append(embedding[0])

    if chunk_embeddings:
        return np.mean(chunk_embeddings, axis=0)
    return np.zeros(model.config.hidden_size)
275