import torch
class UtilsModels:

    @staticmethod
    def compute_embeddings(sentence_transformer, tokenized_sentences, attention_mask):
        """Embed tokenized sentences; returns [batch_size, num_sentences, embedding_dim].

        ``tokenized_sentences`` and ``attention_mask`` are expected to have
        shape [batch_size, num_sentences, seq_len]; ``sentence_transformer``
        is any encoder whose output exposes ``last_hidden_state``.
        """
        # Flatten the batch and num_sentences dimensions
        batch_size, num_sentences, seq_len = tokenized_sentences.size()
        flat_input_ids = tokenized_sentences.view(-1, seq_len)
        flat_attention_mask = attention_mask.view(-1, seq_len) if attention_mask is not None else None
        # Process sentences through the sentence_transformer
        outputs = sentence_transformer(input_ids=flat_input_ids, attention_mask=flat_attention_mask)
        embeddings = outputs.last_hidden_state
        # Pool token embeddings into a single vector per sentence, using a
        # mask-aware mean so padding tokens do not dilute the average; fall
        # back to a plain mean over the sequence dimension when no mask is given.
        if flat_attention_mask is not None:
            mask = flat_attention_mask.unsqueeze(-1).type_as(embeddings)
            sentence_embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            sentence_embeddings = embeddings.mean(dim=1)
        # Reshape back to [batch_size, num_sentences, embedding_dim]
        return sentence_embeddings.view(batch_size, num_sentences, -1)
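

# Usage sketch (not part of the original file): one way to call
# compute_embeddings with a Hugging Face AutoModel/AutoTokenizer pair. The
# checkpoint name "sentence-transformers/all-MiniLM-L6-v2" and the
# one-batch, two-sentence layout below are illustrative assumptions.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # Two sentences forming a single batch item
    sentences = ["The library contains every book.", "Some shelves are unreachable."]
    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    # Restore the [batch_size, num_sentences, seq_len] layout the method expects
    input_ids = encoded["input_ids"].unsqueeze(0)
    attention_mask = encoded["attention_mask"].unsqueeze(0)

    with torch.no_grad():
        sentence_embeddings = UtilsModels.compute_embeddings(model, input_ids, attention_mask)
    print(sentence_embeddings.shape)  # torch.Size([1, 2, 384]) for this checkpoint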