File size: 661 Bytes
6f6920c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from sentence_transformers import SentenceTransformer, models
import torch

# Hub checkpoint and token budget used for the encoder backbone.
_BASE_CHECKPOINT = "BAAI/bge-base-en-v1.5"
_MAX_SEQ_LENGTH = 512
# Extra marker token added to the vocabulary; presumably it separates
# dialogue turns in the input text (TODO confirm against the data-prep code).
_TURN_MARKER = "[TURN]"

# Run on CUDA when a GPU is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Token-level transformer backbone loaded from the checkpoint above.
word_embedding_model = models.Transformer(
    model_name_or_path=_BASE_CHECKPOINT,
    max_seq_length=_MAX_SEQ_LENGTH,
)

_tokenizer = word_embedding_model.tokenizer
# Register the marker as an added special token so it maps to a single id.
_tokenizer.add_tokens([_TURN_MARKER], special_tokens=True)
# Inputs longer than max_seq_length lose tokens from the start rather than
# the end, i.e. the tail of each sequence is what survives truncation.
_tokenizer.truncation_side = "left"
# The vocabulary just grew, so the input embedding matrix must grow with it
# to give the new token id a row.
word_embedding_model.auto_model.resize_token_embeddings(len(_tokenizer))

# Sentence embedding = hidden state of the first ([CLS]) position.
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="cls",
)

# NOTE(review): the published bge-base-en-v1.5 pipeline also ends with a
# Normalize module; it is not appended here, so embeddings are not
# L2-normalized. Confirm this is intended if dot-product scoring is used.
model = SentenceTransformer(
    modules=[word_embedding_model, pooling_model],
    device=device,
)