File size: 2,599 Bytes
343af91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#EmbeddingGenerator.py

from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="transformers.models.bert")
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class EmbeddingGenerator:
    def __init__(self, pavlov_model_name="DeepPavlov/rubert-base-cased", sentence_transformer_model_name="cointegrated/rubert-tiny2"):
        """
        Инициализирует токенизатор и модели для генерации эмбеддингов.

        Args:
            pavlov_model_name (str): Название модели для загрузки Pavlov модели.
            sentence_transformer_model_name (str): Название модели SentenceTransformer для генерации эмбеддингов.
        """
        self.pavlov_tokenizer = AutoTokenizer.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True)
        self.pavlov_model = AutoModel.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True)
        self.sentence_transformer_model = SentenceTransformer(sentence_transformer_model_name)

    def generate_embeddings(self, texts, method="pavlov"):
        """
        Генерирует эмбеддинги для списка текстов с использованием выбранного метода.

        Args:
            texts (list of str): Список текстов для генерации эмбеддингов.
            method (str): Метод генерации эмбеддингов: "pavlov" или "rubert_tiny2".

        Returns:
            np.ndarray: Эмбеддинги текстов.
        """
        if method == "pavlov":
            # Генерация эмбеддингов с использованием Pavlov модели
            inputs = self.pavlov_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                outputs = self.pavlov_model(**inputs)
            # Mean pooling
            embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        elif method == "rubert_tiny2":
            # Генерация эмбеддингов с использованием SentenceTransformer
            embeddings = self.sentence_transformer_model.encode(texts, show_progress_bar=False)
        else:
            raise ValueError("Unsupported method. Choose 'pavlov' or 'rubert_tiny2'.")

        return embeddings