import os
import pandas as pd


class UtilsModels:
    """Model-related helper utilities."""

    @staticmethod
    def compute_embeddings(sentence_transformer, tokenized_sentences, attention_mask):
        """Encode a batch of tokenized sentences into one embedding per sentence.

        Args:
            sentence_transformer: HF-style model invoked as
                ``model(input_ids=..., attention_mask=...)`` and returning an
                object exposing ``last_hidden_state`` of shape
                [batch * num_sentences, seq_len, hidden_dim].
            tokenized_sentences: LongTensor [batch, num_sentences, seq_len]
                of token ids.
            attention_mask: Tensor of the same shape (1 = real token,
                0 = padding), or ``None``.

        Returns:
            Tensor [batch, num_sentences, hidden_dim] of mean-pooled sentence
            embeddings. When ``attention_mask`` is provided, padding positions
            are excluded from the mean (the previous implementation averaged
            over padding too, biasing embeddings toward the pad vector).
        """
        batch_size, num_sentences, seq_len = tokenized_sentences.size()

        # Collapse (batch, num_sentences) so the model sees one flat batch.
        flat_input_ids = tokenized_sentences.view(-1, seq_len)
        flat_attention_mask = (
            attention_mask.view(-1, seq_len) if attention_mask is not None else None
        )

        outputs = sentence_transformer(
            input_ids=flat_input_ids, attention_mask=flat_attention_mask
        )
        embeddings = outputs.last_hidden_state  # [B*S, seq_len, hidden_dim]

        if flat_attention_mask is not None:
            # Mask-aware mean pooling: only real tokens contribute.
            mask = flat_attention_mask.unsqueeze(-1).to(embeddings.dtype)
            summed = (embeddings * mask).sum(dim=1)
            # Clamp guards against division by zero for all-padding rows.
            counts = mask.sum(dim=1).clamp(min=1e-9)
            sentence_embeddings = summed / counts
        else:
            sentence_embeddings = embeddings.mean(dim=1)

        # Restore the (batch, num_sentences) grouping: one vector per sentence.
        return sentence_embeddings.view(batch_size, num_sentences, -1)