#TextAugmentation.py from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer import torch import os os.environ["TOKENIZERS_PARALLELISM"] = "false" class TextAugmentation: def __init__(self, paraphrase_model_name="cointegrated/rut5-base-paraphraser", ru_en_model_name="Helsinki-NLP/opus-mt-ru-en", en_ru_model_name="Helsinki-NLP/opus-mt-en-ru"): # Инициализация модели для перефразирования self.paraphrase_tokenizer = T5Tokenizer.from_pretrained(paraphrase_model_name, legacy=False) self.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(paraphrase_model_name) # Инициализация моделей для обратного перевода self.ru_en_tokenizer = MarianTokenizer.from_pretrained(ru_en_model_name) self.ru_en_model = MarianMTModel.from_pretrained(ru_en_model_name) self.en_ru_tokenizer = MarianTokenizer.from_pretrained(en_ru_model_name) self.en_ru_model = MarianMTModel.from_pretrained(en_ru_model_name) def paraphrase(self, text, num_return_sequences=1): """ Перефразирование текста с использованием модели. Args: text (str): Исходный текст для перефразирования. num_return_sequences (int): Количество вариантов перефразирования. Returns: list[str]: Список вариантов перефразирования текста. """ inputs = self.paraphrase_tokenizer([text], max_length=512, truncation=True, return_tensors="pt") outputs = self.paraphrase_model.generate( **inputs, max_length=128, num_return_sequences=num_return_sequences, do_sample=True, temperature=1.2, top_k=50, top_p=0.90 ) return [self.paraphrase_tokenizer.decode(output, skip_special_tokens=True) for output in outputs] def back_translate(self, text): """ Выполняет обратный перевод текста: русский -> английский -> русский. Args: text (str): Исходный текст для обратного перевода. Returns: str: Текст после обратного перевода. """ # Перевод с русского на английский inputs = self.ru_en_tokenizer(text, return_tensors="pt", truncation=True, max_length=512) with torch.no_grad(): outputs = self.ru_en_model.generate(**inputs) translated_text = self.ru_en_tokenizer.decode(outputs[0], skip_special_tokens=True) # Перевод с английского обратно на русский inputs = self.en_ru_tokenizer(translated_text, return_tensors="pt", truncation=True, max_length=512) with torch.no_grad(): outputs = self.en_ru_model.generate(**inputs) back_translated_text = self.en_ru_tokenizer.decode(outputs[0], skip_special_tokens=True) return back_translated_text