test / TextAugmentation.py
AlexandraGulamova's picture
test
343af91
#TextAugmentation.py
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class TextAugmentation:
def __init__(self,
paraphrase_model_name="cointegrated/rut5-base-paraphraser",
ru_en_model_name="Helsinki-NLP/opus-mt-ru-en",
en_ru_model_name="Helsinki-NLP/opus-mt-en-ru"):
# Инициализация модели для перефразирования
self.paraphrase_tokenizer = T5Tokenizer.from_pretrained(paraphrase_model_name, legacy=False)
self.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(paraphrase_model_name)
# Инициализация моделей для обратного перевода
self.ru_en_tokenizer = MarianTokenizer.from_pretrained(ru_en_model_name)
self.ru_en_model = MarianMTModel.from_pretrained(ru_en_model_name)
self.en_ru_tokenizer = MarianTokenizer.from_pretrained(en_ru_model_name)
self.en_ru_model = MarianMTModel.from_pretrained(en_ru_model_name)
def paraphrase(self, text, num_return_sequences=1):
"""
Перефразирование текста с использованием модели.
Args:
text (str): Исходный текст для перефразирования.
num_return_sequences (int): Количество вариантов перефразирования.
Returns:
list[str]: Список вариантов перефразирования текста.
"""
inputs = self.paraphrase_tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
outputs = self.paraphrase_model.generate(
**inputs,
max_length=128,
num_return_sequences=num_return_sequences,
do_sample=True,
temperature=1.2,
top_k=50,
top_p=0.90
)
return [self.paraphrase_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
def back_translate(self, text):
"""
Выполняет обратный перевод текста: русский -> английский -> русский.
Args:
text (str): Исходный текст для обратного перевода.
Returns:
str: Текст после обратного перевода.
"""
# Перевод с русского на английский
inputs = self.ru_en_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = self.ru_en_model.generate(**inputs)
translated_text = self.ru_en_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Перевод с английского обратно на русский
inputs = self.en_ru_tokenizer(translated_text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = self.en_ru_model.generate(**inputs)
back_translated_text = self.en_ru_tokenizer.decode(outputs[0], skip_special_tokens=True)
return back_translated_text