import gradio as gr from utils.find_ng_word import get_ng_wordlist, get_ng_wordlist_from_saved, search_ng_word from utils.llm import load_llm_from_pretrained, inference, estimate_probability wordlist_1_path_s = "input/ng_wordlists/ng_wordlist_1_sexual.txt" wordlist_2_path_s = "input/ng_wordlists/ng_wordlist_2_sexual.txt" wordlist_1_path_o = "input/ng_wordlists/ng_wordlist_1_offensive.txt" wordlist_2_path_o = "input/ng_wordlists/ng_wordlist_2_offensive.txt" pretrained_model_path = "input/llm_weights" print("モデルをロード") ng_wordlist_1_s, ng_wordlist_2_s = get_ng_wordlist_from_saved(wordlist_1_path_s, wordlist_2_path_s) ng_wordlist_1_o, ng_wordlist_2_o = get_ng_wordlist_from_saved(wordlist_1_path_o, wordlist_2_path_o) model, tokenizer = load_llm_from_pretrained(pretrained_model_path) # 検出結果を生成 def detect_ng_word(input_text): response = [] rtn_s = search_ng_word(input_text, ng_wordlist_1_s, ng_wordlist_2_s) rtn_o = search_ng_word(input_text, ng_wordlist_1_o, ng_wordlist_2_o) rtn = rtn_s + rtn_o if len(rtn) == 0: response.append("NGワードは検知されませんでした \n") else: response.append('以下のNGワードを検知しました \n') for rtn_i in rtn: ng_word = str(rtn_i) + " \n" response.append(ng_word) rtn_s = [ri + "(sexual)" for ri in rtn_s] rtn_o = [ri + "(offensive)" for ri in rtn_o] ngword_with_label = rtn_s + rtn_o output = inference(model, tokenizer, input_text, ngword_with_label) if output == "はい。攻撃的だから。": response.append('不適切な内容を検知しました(攻撃的)') elif output == "はい。暴力的だから。": response.append('不適切な内容を検知しました(暴力的)') elif output == "はい。差別的だから。": response.append('不適切な内容を検知しました(差別的)') elif output == "はい。性的だから。": response.append('不適切な内容を検知しました(性的)') elif output == "はい。政治的だから。": response.append('不適切な内容を検知しました(政治的)') else: response.append("不適切な内容は検知されませんでした") return response def estimate_ng_probability(input_text, threshold=0.3): response = [] rtn_s = search_ng_word(input_text, ng_wordlist_1_s, ng_wordlist_2_s) rtn_o = search_ng_word(input_text, ng_wordlist_1_o, ng_wordlist_2_o) rtn = rtn_s + rtn_o if len(rtn) == 0: response.append("NGワードは検知されませんでした \n") else: response.append('以下のNGワードを検知しました \n') for rtn_i in rtn: ng_word = str(rtn_i) + " \n" response.append(ng_word) rtn_s = [ri + "(sexual)" for ri in rtn_s] rtn_o = [ri + "(offensive)" for ri in rtn_o] ngword_with_label = rtn_s + rtn_o prob = estimate_probability(model, tokenizer, input_text, ngword_with_label) if prob > threshold: response.append(f"不適切な内容を検知しました(NG度:{100*prob:.0f}%)") else: response.append(f"不適切な内容は検知されませんでした(NG度:{100*prob:.0f}%)") return response # 会話履歴用リスト型変数 message_history = [] def chat(user_msg, mode, threshold=None): """ AIとの会話を実行後、全会話履歴を返す user_msg: 入力されたユーザのメッセージ """ global message_history # ユーザの会話を履歴に追加 message_history.append({ "role": "user", "content": user_msg }) # AIの回答を履歴に追加 if mode == "理由を出力": response = detect_ng_word(user_msg) elif mode == "確率を出力": response = estimate_ng_probability(user_msg, threshold) assistant_msg = " ".join(response) message_history.append({ "role": "assistant", "content": assistant_msg }) # 全会話履歴をChatbot用タプル・リストに変換して返す return [(message_history[i]["content"], message_history[i+1]["content"]) for i in range(0, len(message_history)-1, 2)] with gr.Blocks() as demo: # チャットボットUI処理 chatbot = gr.Chatbot() input = gr.Textbox(show_label=False, placeholder="チェックしたい文章を入力してください") inputs=[input, gr.Radio(["確率を出力", "理由を出力"], value= "確率を出力", label="モードを選択"), gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="NG度の閾値")] input.submit(fn=chat, inputs=inputs, outputs=chatbot) # メッセージ送信されたら、AIと会話してチャット欄に全会話内容を表示 input.submit(fn=lambda: "", inputs=None, outputs=input) # (上記に加えて)入力欄をクリア demo.launch(share=True)