Spaces:
Sleeping
Sleeping
#EmbeddingGenerator.py | |
from transformers import AutoTokenizer, AutoModel | |
from sentence_transformers import SentenceTransformer | |
import torch | |
import numpy as np | |
import warnings | |
warnings.filterwarnings("ignore", category=UserWarning, module="transformers.models.bert") | |
import os | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
class EmbeddingGenerator: | |
def __init__(self, pavlov_model_name="DeepPavlov/rubert-base-cased", sentence_transformer_model_name="cointegrated/rubert-tiny2"): | |
""" | |
Инициализирует токенизатор и модели для генерации эмбеддингов. | |
Args: | |
pavlov_model_name (str): Название модели для загрузки Pavlov модели. | |
sentence_transformer_model_name (str): Название модели SentenceTransformer для генерации эмбеддингов. | |
""" | |
self.pavlov_tokenizer = AutoTokenizer.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True) | |
self.pavlov_model = AutoModel.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True) | |
self.sentence_transformer_model = SentenceTransformer(sentence_transformer_model_name) | |
def generate_embeddings(self, texts, method="pavlov"): | |
""" | |
Генерирует эмбеддинги для списка текстов с использованием выбранного метода. | |
Args: | |
texts (list of str): Список текстов для генерации эмбеддингов. | |
method (str): Метод генерации эмбеддингов: "pavlov" или "rubert_tiny2". | |
Returns: | |
np.ndarray: Эмбеддинги текстов. | |
""" | |
if method == "pavlov": | |
# Генерация эмбеддингов с использованием Pavlov модели | |
inputs = self.pavlov_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
with torch.no_grad(): | |
outputs = self.pavlov_model(**inputs) | |
# Mean pooling | |
embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy() | |
elif method == "rubert_tiny2": | |
# Генерация эмбеддингов с использованием SentenceTransformer | |
embeddings = self.sentence_transformer_model.encode(texts, show_progress_bar=False) | |
else: | |
raise ValueError("Unsupported method. Choose 'pavlov' or 'rubert_tiny2'.") | |
return embeddings | |