File size: 5,007 Bytes
e7a412f
 
 
34d1256
e7a412f
 
352aba6
 
 
 
e7a412f
 
 
 
 
 
 
 
 
 
 
 
1b44d13
 
e7a412f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80086ff
34d1256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7a412f
 
 
 
9c8197a
e7a412f
 
 
 
 
 
 
 
 
 
 
 
 
34d1256
 
 
 
 
e7a412f
 
 
 
 
 
 
 
 
 
 
 
 
9c8197a
82e75aa
8d0fb13
 
82e75aa
8d0fb13
9c8197a
e7a412f
 
34d1256
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr

from utils.find_ng_word import get_ng_wordlist, get_ng_wordlist_from_saved, search_ng_word
from utils.llm import load_llm_from_pretrained, inference, estimate_probability


wordlist_1_path_s = "input/ng_wordlists/ng_wordlist_1_sexual.txt"
wordlist_2_path_s = "input/ng_wordlists/ng_wordlist_2_sexual.txt"
wordlist_1_path_o = "input/ng_wordlists/ng_wordlist_1_offensive.txt"
wordlist_2_path_o = "input/ng_wordlists/ng_wordlist_2_offensive.txt"

pretrained_model_path = "input/llm_weights"

print("モデルをロード")
ng_wordlist_1_s, ng_wordlist_2_s = get_ng_wordlist_from_saved(wordlist_1_path_s, wordlist_2_path_s)
ng_wordlist_1_o, ng_wordlist_2_o = get_ng_wordlist_from_saved(wordlist_1_path_o, wordlist_2_path_o)

model, tokenizer = load_llm_from_pretrained(pretrained_model_path)

# 検出結果を生成
def detect_ng_word(input_text):
    response = []
    rtn_s = search_ng_word(input_text, ng_wordlist_1_s, ng_wordlist_2_s)
    rtn_o = search_ng_word(input_text, ng_wordlist_1_o, ng_wordlist_2_o)
    rtn = rtn_s + rtn_o

    if len(rtn) == 0:
        response.append("NGワードは検知されませんでした  \n")
    else:
        response.append('以下のNGワードを検知しました  \n')
        for rtn_i in rtn:
            ng_word = str(rtn_i) + "  \n"
            response.append(ng_word)

    rtn_s = [ri + "(sexual)" for ri in rtn_s]
    rtn_o = [ri + "(offensive)" for ri in rtn_o]
    ngword_with_label = rtn_s + rtn_o
    
    output = inference(model, tokenizer, input_text, ngword_with_label)
    
    if output == "はい。攻撃的だから。</s>":
        response.append('不適切な内容を検知しました(攻撃的)')
    elif output == "はい。暴力的だから。</s>":
        response.append('不適切な内容を検知しました(暴力的)')
    elif output == "はい。差別的だから。</s>":
        response.append('不適切な内容を検知しました(差別的)')
    elif output == "はい。性的だから。</s>":
        response.append('不適切な内容を検知しました(性的)')
    elif output == "はい。政治的だから。</s>":
        response.append('不適切な内容を検知しました(政治的)')
    else:
        response.append("不適切な内容は検知されませんでした")
      
    return response

def estimate_ng_probability(input_text, threshold=0.3):
    response = []
    rtn_s = search_ng_word(input_text, ng_wordlist_1_s, ng_wordlist_2_s)
    rtn_o = search_ng_word(input_text, ng_wordlist_1_o, ng_wordlist_2_o)
    rtn = rtn_s + rtn_o

    if len(rtn) == 0:
        response.append("NGワードは検知されませんでした  \n")
    else:
        response.append('以下のNGワードを検知しました  \n')
        for rtn_i in rtn:
            ng_word = str(rtn_i) + "  \n"
            response.append(ng_word)

    rtn_s = [ri + "(sexual)" for ri in rtn_s]
    rtn_o = [ri + "(offensive)" for ri in rtn_o]
    ngword_with_label = rtn_s + rtn_o
    
    prob = estimate_probability(model, tokenizer, input_text, ngword_with_label)
    
    if prob > threshold:
        response.append(f"不適切な内容を検知しました(NG度:{100*prob:.0f}%)")
    else:
        response.append(f"不適切な内容は検知されませんでした(NG度:{100*prob:.0f}%)")
      
    return response


# 会話履歴用リスト型変数
message_history = []

def chat(user_msg, mode, threshold=None):
    """
    AIとの会話を実行後、全会話履歴を返す
    user_msg: 入力されたユーザのメッセージ
    """
    global message_history

    # ユーザの会話を履歴に追加
    message_history.append({
        "role": "user",
        "content": user_msg
    })

    # AIの回答を履歴に追加
    if mode == "理由を出力":
        response = detect_ng_word(user_msg)
    elif mode == "確率を出力":
        response = estimate_ng_probability(user_msg, threshold)

    assistant_msg = " ".join(response)
    message_history.append({
        "role": "assistant",
        "content": assistant_msg
    })

    # 全会話履歴をChatbot用タプル・リストに変換して返す
    return [(message_history[i]["content"], message_history[i+1]["content"]) for i in range(0, len(message_history)-1, 2)]


with gr.Blocks() as demo:
    # チャットボットUI処理
    chatbot = gr.Chatbot()

    input = gr.Textbox(show_label=False, placeholder="チェックしたい文章を入力してください")

    inputs=[input, gr.Radio(["確率を出力", "理由を出力"], value= "確率を出力", label="モードを選択"), gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="NG度の閾値")]
    
    input.submit(fn=chat, inputs=inputs, outputs=chatbot) # メッセージ送信されたら、AIと会話してチャット欄に全会話内容を表示
    
    input.submit(fn=lambda: "", inputs=None, outputs=input) # (上記に加えて)入力欄をクリア

demo.launch(share=True)