"""Streamlit demo: detoxify Russian text with the ruT5 model ``orzhan/rut5-base-detox-v2``.

The user picks one of the bundled example sentences (or types their own text),
the script encodes it, runs deterministic beam-search generation and shows the
rewritten text.
"""
import streamlit as st
import transformers
# NOTE(review): AutoModelWithLMHead is no longer used below and is deprecated
# upstream; it is kept only so existing imports stay untouched.
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSeq2SeqLM
import torch

torch.manual_seed(0)

model_name = "orzhan/rut5-base-detox-v2"

# Streamlit re-executes this whole script on every widget interaction, so the
# tokenizer and model must be cached across reruns. ``st.cache_resource`` is the
# supported API for shared, unhashable objects such as models. The deprecated
# ``st.cache`` is only a fallback for old Streamlit releases; there
# ``allow_output_mutation=True`` stops it from hashing the returned model on
# every call.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:  # Streamlit < 1.18
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_tokenizer(model_name):
    """Load the tokenizer for *model_name* once and reuse it across reruns."""
    return AutoTokenizer.from_pretrained(model_name)


@_cache_resource
def load_model(model_name):
    """Load the sequence-to-sequence model for *model_name* once and reuse it.

    ruT5 is an encoder-decoder model, so ``AutoModelForSeq2SeqLM`` is used
    instead of the deprecated ``AutoModelWithLMHead``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)


tokenizer = load_tokenizer(model_name)
model = load_model(model_name)


def infer(input_ids):
    """Rewrite the encoded text with deterministic beam search.

    Args:
        input_ids: token-id tensor as produced by
            ``tokenizer.encode(..., return_tensors="pt")``.

    Returns:
        The tensor of generated token ids returned by ``model.generate``,
        one row per returned sequence (``num_return_sequences=1`` here).
    """
    return model.generate(
        input_ids=input_ids,
        max_length=40,  # caps the generated sequence length (in tokens)
        do_sample=False,  # beam search only, no sampling -> deterministic output
        num_return_sequences=1,
        num_beams=8,
        length_penalty=1.0,
        no_repeat_ngram_size=3,  # forbid repeating any 3-gram in the output
    )


default_value = "всегда ненавидел этих ублюдочных тварей"
# Sentinel entry in the selectbox that switches to a free-form text area.
_CUSTOM_TEXT_OPTION = "--- свой текст ---"
examples = [
    default_value,
    "убила бы этих выродков и их родителей.",
    "пошли вы сука все на хуй со своим коронавирусом...",
    "Перед барином выслуживаешься холоп?",
    "просто , зарплату шакалов отрабатывают , а больше на крысенышей похожи...",
    _CUSTOM_TEXT_OPTION,
]

st.title("Демо детоксификации на ruT5")
sent = st.selectbox("Пример", examples)
if sent == _CUSTOM_TEXT_OPTION:
    sent = st.text_area("Исходный текст", default_value)

# The return value is deliberately unused: clicking the button just triggers a
# script rerun (picking up any edit in the text area); generation below runs on
# every rerun.
st.button('Сделать нетоксичным')

if not sent.strip():
    # The T5 tokenizer still appends an end-of-sequence token to empty input,
    # so without this guard the model would "rewrite" nothing into noise.
    st.warning("Введите текст для преобразования.")
    st.stop()

encoded_prompt = tokenizer.encode(sent, add_special_tokens=True, return_tensors="pt")
input_ids = encoded_prompt
output_sequences = infer(input_ids)

generated_texts = tokenizer.batch_decode(
    output_sequences,
    clean_up_tokenization_spaces=True,
    skip_special_tokens=True,
)
text = generated_texts[-1]

st.write("Преобразованный текст: ")
st.write(text)