File size: 2,580 Bytes
c6eaf06
eb91b0e
 
149af33
 
eb91b0e
e450399
eb91b0e
 
 
 
 
 
 
 
 
 
 
 
 
f991e69
d7af6d0
eb91b0e
d7af6d0
1fe7ac7
2adce53
c5a7489
eb91b0e
 
 
661df71
763cd70
 
 
 
d3b2b57
 
eb91b0e
 
31ff417
763cd70
d3b2b57
 
 
 
eb91b0e
84fbaf6
eb91b0e
3d5ce55
eb91b0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c99011b
eb91b0e
 
 
 
 
 
f991e69
eb91b0e
 
 
 
 
db6b790
eb91b0e
c6eaf06
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelWithLMHead
import torch
torch.manual_seed(0)

model_name = "orzhan/rut5-base-detox-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
@st.cache
def load_model(model_name):
	    model = AutoModelWithLMHead.from_pretrained(model_name)
	    return model
	
model = load_model(model_name)


def infer(input_ids):
    
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=40,
        do_sample=False,
        num_return_sequences=1,
        num_beams=8,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
      #  do_sample=True, top_p=0.9, temperature=0.9
    )

    return output_sequences
default_value = "всегда ненавидел этих ублюдочных тварей"
examples = [default_value,
"убила бы этих выродков и их родителей.",
"пошли вы сука все на хуй со своим коронавирусом...",
"Перед барином выслуживаешься холоп?",
"просто , зарплату шакалов отрабатывают , а больше на крысенышей похожи...",
"--- свой текст ---"]

#prompts
st.title("Демо детоксификации на ruT5")
sent = st.selectbox("Пример", examples)
if sent == "--- свой текст ---":
    sent = st.text_area("Исходный текст", default_value)
#if custom_sent == default_value:
#    custom_sent = sent

st.button('Сделать нетоксичным')

encoded_prompt = tokenizer.encode(sent, add_special_tokens=True, return_tensors="pt")
if encoded_prompt.size()[-1] == 0:
    input_ids = None
else:
    input_ids = encoded_prompt


output_sequences = infer(input_ids)



for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
    print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
    generated_sequences = generated_sequence.tolist()

    # Decode text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)

    # Remove all text after the stop token
    #text = text[: text.find(args.stop_token) if args.stop_token else None]

    # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
    total_sequence = (
        text
    )

    generated_sequences.append(total_sequence)
    print(total_sequence)

st.write("Преобразованный текст: ")
st.write(generated_sequences[-1])