"""Streamlit demo: detoxification of Russian text with orzhan/rut5-base-detox (ruT5)."""

import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "orzhan/rut5-base-detox"
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Cache the model across Streamlit reruns; allow_output_mutation=True stops
# Streamlit from trying to hash the (unhashable, large) torch model each run.
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Download (once) and return the seq2seq detox model.

    AutoModelWithLMHead is deprecated; AutoModelForSeq2SeqLM resolves to the
    same T5 conditional-generation class for this checkpoint.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model


model = load_model(model_name)


def infer(input_ids):
    """Run beam-search generation on a batch of token ids.

    input_ids: torch.LongTensor of shape (batch, seq_len) from the tokenizer.
    Returns the generated sequences tensor from model.generate.
    """
    # Pure beam search: do_sample is off, so top_p/temperature (present but
    # unused in the original) are intentionally omitted.
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=40,
        num_return_sequences=1,
        num_beams=8,
        length_penalty=1.5,
        no_repeat_ngram_size=4,
    )
    return output_sequences


default_value = "Еще один петух решил легко заработать на донатах."

# --- UI ---
st.title("Демо детоксификации на ruT5")
sent = st.text_area("Text", default_value)

# Generate only when the button is pressed: the original ignored the button's
# return value and ran generation on every Streamlit rerun.
if st.button('Сделать нетактичным'):
    encoded_prompt = tokenizer.encode(sent, add_special_tokens=False, return_tensors="pt")
    if encoded_prompt.size()[-1] == 0:
        # Empty input: nothing to transform; the original would have called
        # model.generate(input_ids=None) here.
        st.write("Введите текст для преобразования.")
    else:
        output_sequences = infer(encoded_prompt)
        # num_return_sequences=1, so decode the single returned sequence.
        text = tokenizer.decode(
            output_sequences[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        st.write("Преобразованный текст: ")
        st.write(text)