import os
import pandas as pd

class UtilsModels:
    @staticmethod
    def compute_embeddings(sentence_transformer, tokenized_sentences, attention_mask):
        """Encode a batch of tokenized sentences into one embedding per sentence.

        Args:
            sentence_transformer: Callable model taking ``input_ids`` and
                ``attention_mask`` and returning an object with a
                ``last_hidden_state`` tensor of shape
                [batch * num_sentences, seq_len, embedding_dim].
            tokenized_sentences: Long tensor of token ids, shape
                [batch_size, num_sentences, seq_len].
            attention_mask: Mask of the same shape (1 = real token,
                0 = padding), or None.

        Returns:
            Tensor of shape [batch_size, num_sentences, embedding_dim].
        """
        # Flatten the batch and num_sentences dimensions so the transformer
        # sees a plain [N, seq_len] batch.
        batch_size, num_sentences, seq_len = tokenized_sentences.size()
        flat_input_ids = tokenized_sentences.view(-1, seq_len)
        flat_attention_mask = attention_mask.view(-1, seq_len) if attention_mask is not None else None

        # Process sentences through the sentence_transformer
        outputs = sentence_transformer(input_ids=flat_input_ids, attention_mask=flat_attention_mask)
        embeddings = outputs.last_hidden_state

        if flat_attention_mask is not None:
            # Masked mean pooling: average only over real tokens so that
            # padding positions do not dilute the sentence embedding.
            mask = flat_attention_mask.unsqueeze(-1).to(embeddings.dtype)
            summed = (embeddings * mask).sum(dim=1)
            # clamp guards against division by zero for all-padding rows
            counts = mask.sum(dim=1).clamp(min=1e-9)
            sentence_embeddings = summed / counts
        else:
            # No mask available: fall back to a plain mean over seq_len.
            sentence_embeddings = embeddings.mean(dim=1)

        # Reshape back to [batch_size, num_sentences, embedding_dim]
        return sentence_embeddings.view(batch_size, num_sentences, -1)