Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -235,34 +235,32 @@ def classify_emotion(text, classifier):
|
|
235 |
|
236 |
def get_embedding_for_text(text, tokenizer, model):
|
237 |
"""Get embedding for complete text."""
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
|
242 |
-
for
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
if len(tokens) >= 512:
|
247 |
-
if current_text:
|
248 |
-
chunks.append(current_text)
|
249 |
-
current_text = word
|
250 |
-
else:
|
251 |
-
current_text = test_text
|
252 |
|
253 |
-
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
|
|
256 |
chunk_embeddings = []
|
257 |
for chunk in chunks:
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
truncation=True
|
264 |
-
)
|
265 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
266 |
|
267 |
with torch.no_grad():
|
268 |
outputs = model(**inputs)[0]
|
|
|
235 |
|
236 |
def get_embedding_for_text(text, tokenizer, model):
|
237 |
"""Get embedding for complete text."""
|
238 |
+
# First encode the full text to get actual tokens
|
239 |
+
encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
|
240 |
+
all_tokens = encoded['input_ids'][0]
|
241 |
|
242 |
+
# Split into chunks of 510 tokens to leave room for [CLS] and [SEP]
|
243 |
+
chunk_size = 510
|
244 |
+
chunks = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
|
246 |
+
for i in range(0, len(all_tokens), chunk_size):
|
247 |
+
chunk_tokens = all_tokens[i:i + chunk_size]
|
248 |
+
# Add [CLS] and [SEP] tokens
|
249 |
+
chunk_tokens = torch.cat([
|
250 |
+
torch.tensor([tokenizer.cls_token_id]),
|
251 |
+
chunk_tokens,
|
252 |
+
torch.tensor([tokenizer.sep_token_id])
|
253 |
+
])
|
254 |
+
chunks.append(chunk_tokens)
|
255 |
|
256 |
+
# Get embeddings for each chunk
|
257 |
chunk_embeddings = []
|
258 |
for chunk in chunks:
|
259 |
+
# Create proper input format
|
260 |
+
inputs = {
|
261 |
+
'input_ids': chunk.unsqueeze(0).to(model.device),
|
262 |
+
'attention_mask': torch.ones_like(chunk.unsqueeze(0)).to(model.device)
|
263 |
+
}
|
|
|
|
|
|
|
264 |
|
265 |
with torch.no_grad():
|
266 |
outputs = model(**inputs)[0]
|