kb2022 commited on
Commit
34d1256
·
verified ·
1 Parent(s): 1f690c3

[add] probability mode

Browse files
Files changed (1) hide show
  1. app.py +41 -5
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
 
3
  from utils.find_ng_word import get_ng_wordlist, get_ng_wordlist_from_saved, search_ng_word
4
- from utils.llm import load_llm_from_pretrained, inference
5
 
6
 
7
  wordlist_1_path_s = "input/ng_wordlists/ng_wordlist_1_sexual.txt"
@@ -53,11 +53,38 @@ def detect_ng_word(input_text):
53
 
54
  return response
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # 会話履歴用リスト型変数
58
  message_history = []
59
 
60
- def chat(user_msg):
61
  """
62
  AIとの会話を実行後、全会話履歴を返す
63
  user_msg: 入力されたユーザのメッセージ
@@ -71,7 +98,11 @@ def chat(user_msg):
71
  })
72
 
73
  # AIの回答を履歴に追加
74
- response = detect_ng_word(user_msg)
 
 
 
 
75
  assistant_msg = " ".join(response)
76
  message_history.append({
77
  "role": "assistant",
@@ -86,7 +117,12 @@ with gr.Blocks() as demo:
86
  # チャットボットUI処理
87
  chatbot = gr.Chatbot()
88
  input = gr.Textbox(show_label=False, placeholder="チェックしたい文章を入力してください")
89
- input.submit(fn=chat, inputs=input, outputs=chatbot) # メッセージ送信されたら、AIと会話してチャット欄に全会話内容を表示
 
 
 
 
 
90
  input.submit(fn=lambda: "", inputs=None, outputs=input) # (上記に加えて)入力欄をクリア
91
 
92
- demo.launch()
 
1
  import gradio as gr
2
 
3
  from utils.find_ng_word import get_ng_wordlist, get_ng_wordlist_from_saved, search_ng_word
4
+ from utils.llm import load_llm_from_pretrained, inference, estimate_probability
5
 
6
 
7
  wordlist_1_path_s = "input/ng_wordlists/ng_wordlist_1_sexual.txt"
 
53
 
54
  return response
55
 
56
+ def estimate_ng_probability(input_text, thresold=0.3):
57
+ response = []
58
+ rtn_s = search_ng_word(input_text, ng_wordlist_1_s, ng_wordlist_2_s)
59
+ rtn_o = search_ng_word(input_text, ng_wordlist_1_o, ng_wordlist_2_o)
60
+ rtn = rtn_s + rtn_o
61
+
62
+ if len(rtn) == 0:
63
+ response.append("NGワードは検知されませんでした \n")
64
+ else:
65
+ response.append('以下のNGワードを検知しました \n')
66
+ for rtn_i in rtn:
67
+ ng_word = str(rtn_i) + " \n"
68
+ response.append(ng_word)
69
+
70
+ rtn_s = [ri + "(sexual)" for ri in rtn_s]
71
+ rtn_o = [ri + "(offensive)" for ri in rtn_o]
72
+ ngword_with_label = rtn_s + rtn_o
73
+
74
+ prob = estimate_probability(model, tokenizer, input_text, ngword_with_label)
75
+
76
+ if prob > threshold:
77
+ response.append(f"不適切な内容を検知しました(NG度:{100*prob:.0f}%)")
78
+ else:
79
+ response.append(f"不適切な内容は検知されませんでした(NG度:{100*prob:.0f}%)")
80
+
81
+ return response
82
+
83
 
84
  # 会話履歴用リスト型変数
85
  message_history = []
86
 
87
+ def chat(user_msg, mode, threshold):
88
  """
89
  AIとの会話を実行後、全会話履歴を返す
90
  user_msg: 入力されたユーザのメッセージ
 
98
  })
99
 
100
  # AIの回答を履歴に追加
101
+ if mode == "理由を出力":
102
+ response = detect_ng_word(user_msg)
103
+ elif mode == "確率を出力":
104
+ response = estimate_ng_probability(user_msg, threshold)
105
+
106
  assistant_msg = " ".join(response)
107
  message_history.append({
108
  "role": "assistant",
 
117
  # チャットボットUI処理
118
  chatbot = gr.Chatbot()
119
  input = gr.Textbox(show_label=False, placeholder="チェックしたい文章を入力してください")
120
+ mode = gr.Radio(["理由を出力", "確率を出力"], label="モードを選択")
121
+ if mode == "確率を出力":
122
+ threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="NG度の閾値")
123
+ else:
124
+ threshold = None
125
+ input.submit(fn=chat, inputs=[input, mode, threshold], outputs=chatbot) # メッセージ送信されたら、AIと会話してチャット欄に全会話内容を表示
126
  input.submit(fn=lambda: "", inputs=None, outputs=input) # (上記に加えて)入力欄をクリア
127
 
128
+ demo.launch(share=True)