File size: 2,580 Bytes
c6eaf06 eb91b0e 149af33 eb91b0e e450399 eb91b0e f991e69 d7af6d0 eb91b0e d7af6d0 1fe7ac7 2adce53 c5a7489 eb91b0e 661df71 763cd70 d3b2b57 eb91b0e 31ff417 763cd70 d3b2b57 eb91b0e 84fbaf6 eb91b0e 3d5ce55 eb91b0e c99011b eb91b0e f991e69 eb91b0e db6b790 eb91b0e c6eaf06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelWithLMHead
import torch
# Seed torch's RNG; this only affects results if generation uses sampling.
torch.manual_seed(0)
model_name = "orzhan/rut5-base-detox-v2"

# Streamlit re-executes this whole script on every widget interaction, so heavy
# objects (tokenizer, model) must be cached across reruns. `st.cache_resource`
# (Streamlit >= 1.18) is the supported API for such objects; on older versions
# fall back to the legacy `st.cache`. There, `allow_output_mutation=True` stops
# it from hashing the returned object on every call, which is slow for a model.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def _load_tokenizer(model_name):
    """Load the tokenizer for *model_name*, cached across Streamlit reruns."""
    return AutoTokenizer.from_pretrained(model_name)


tokenizer = _load_tokenizer(model_name)


@_cache_resource
def load_model(model_name):
    """Load the pretrained checkpoint *model_name*, cached across reruns.

    Args:
        model_name: Hugging Face hub id or local path of the checkpoint.

    Returns:
        The model instance produced by ``AutoModelWithLMHead.from_pretrained``.
    """
    # NOTE(review): AutoModelWithLMHead is deprecated in transformers; for a T5
    # checkpoint AutoModelForSeq2SeqLM is the documented replacement.
    return AutoModelWithLMHead.from_pretrained(model_name)


model = load_model(model_name)
def infer(
    input_ids,
    *,
    max_length=40,
    num_beams=8,
    length_penalty=1.0,
    no_repeat_ngram_size=3,
    **generate_kwargs,
):
    """Generate a rewritten sequence for *input_ids* with the module-level ``model``.

    By default this runs deterministic beam search (``do_sample=False``) and
    returns a single candidate, matching the demo's original settings.

    Args:
        input_ids: Token-id tensor, e.g. from
            ``tokenizer.encode(..., return_tensors="pt")`` (presumably shape
            ``(batch, seq_len)`` — confirm with the caller).
        max_length: Cap on the generated sequence length passed to
            ``generate``; outputs for long inputs may be cut off at this length.
        num_beams: Beam width for beam search.
        length_penalty: Beam-search length penalty passed to ``generate``.
        no_repeat_ngram_size: Forbid repeating any n-gram of this size.
        **generate_kwargs: Extra keyword arguments forwarded to
            ``model.generate``; they override the defaults above (e.g.
            ``do_sample=True, top_p=0.9, temperature=0.9`` to sample instead).

    Returns:
        Whatever ``model.generate`` returns — by default a tensor of generated
        token ids, one row per returned sequence.
    """
    settings = {
        "max_length": max_length,
        "do_sample": False,
        "num_return_sequences": 1,
        "num_beams": num_beams,
        "length_penalty": length_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }
    settings.update(generate_kwargs)
    return model.generate(input_ids=input_ids, **settings)
# Sample toxic sentence: the first preset and the pre-filled custom-input text.
default_value = "всегда ненавидел этих ублюдочных тварей"

# Options shown in the example selector. The final entry is a sentinel that
# lets the user type their own text instead of picking a preset.
examples = [
    default_value,
    "убила бы этих выродков и их родителей.",
    "пошли вы сука все на хуй со своим коронавирусом...",
    "Перед барином выслуживаешься холоп?",
    "просто , зарплату шакалов отрабатывают , а больше на крысенышей похожи...",
    "--- свой текст ---",
]
# --- UI: choose or type a sentence, then show its detoxified rewrite ---
st.title("Демо детоксификации на ruT5")
sent = st.selectbox("Пример", examples)
if sent == "--- свой текст ---":
    # Sentinel option selected: let the user enter arbitrary text.
    sent = st.text_area("Исходный текст", default_value)

# The return value is not used: clicking any Streamlit button re-runs the
# script, and that rerun is what regenerates the output below.
st.button('Сделать нетоксичным')

# Only call the model for non-blank text. A T5 tokenizer with special tokens
# never produces an empty encoding (it always appends </s>), so checking the
# encoded length cannot catch blank input; check the text itself instead.
if sent and sent.strip():
    input_ids = tokenizer.encode(sent, add_special_tokens=True, return_tensors="pt")
    output_sequences = infer(input_ids)
    # Decode the top-ranked sequence (infer returns one by default).
    text = tokenizer.decode(
        output_sequences[0],
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
    st.write("Преобразованный текст: ")
    st.write(text)
|