Spaces:
Sleeping
Sleeping
without chat mode
Browse files
app.py
CHANGED
|
@@ -104,24 +104,24 @@ class SessionState:
|
|
| 104 |
|
| 105 |
self.cache_clear()
|
| 106 |
|
| 107 |
-
# --------------------------------------
|
| 108 |
-
# Load Chat History as a list
|
| 109 |
-
# --------------------------------------
|
| 110 |
-
def load_chat_history(self) -> list:
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
# --------------------------------------
|
| 127 |
# 自作TextSplitter(テキストをLLMのトークン数内に分割)
|
|
@@ -627,9 +627,21 @@ def user(ss: SessionState, query) -> (SessionState, list):
|
|
| 627 |
return ss, chat_history
|
| 628 |
|
| 629 |
def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
|
| 630 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
ss = qa_predict(ss, query) # LLMで回答を生成
|
| 632 |
|
|
|
|
| 633 |
else:
|
| 634 |
ss = conversation_prep(ss)
|
| 635 |
ss = chat_predict(ss, query)
|
|
@@ -637,7 +649,7 @@ def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
|
|
| 637 |
return ss, "" # ssとquery欄(空欄)
|
| 638 |
|
| 639 |
def chat_predict(ss: SessionState, query) -> SessionState:
|
| 640 |
-
response = ss.conversation_chain.predict(
|
| 641 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
| 642 |
return ss
|
| 643 |
|
|
@@ -694,14 +706,13 @@ def qa_predict(ss: SessionState, query) -> SessionState:
|
|
| 694 |
|
| 695 |
# 回答を1文字ずつチャット画面に表示する
|
| 696 |
def show_response(ss: SessionState) -> str:
|
| 697 |
-
# chat_history = ss.load_chat_history() # メモリから会話履歴をリスト型で取得
|
| 698 |
-
# response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
|
| 699 |
-
# chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
|
| 700 |
-
|
| 701 |
chat_history = [list(item) for item in ss.dialogue] # タプルをリストに変換して、メモリから会話履歴を取得
|
| 702 |
response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
|
| 703 |
chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
|
| 704 |
|
|
|
|
|
|
|
|
|
|
| 705 |
for character in response:
|
| 706 |
chat_history[-1][1] += character
|
| 707 |
time.sleep(0.05)
|
|
@@ -716,7 +727,11 @@ with gr.Blocks() as demo:
|
|
| 716 |
# API KEY をセット/クリアする関数
|
| 717 |
# --------------------------------------
|
| 718 |
def openai_api_setfn(openai_api_key) -> str:
|
| 719 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
os.environ["OPENAI_API_KEY"] = ""
|
| 721 |
status_message = "❌ 有効なAPIキーを入力してください"
|
| 722 |
return status_message
|
|
@@ -761,7 +776,7 @@ with gr.Blocks() as demo:
|
|
| 761 |
'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
|
| 762 |
'oshizo/sbert-jsnli-luke-japanese-base-lite',
|
| 763 |
'text-embedding-ada-002',
|
| 764 |
-
"None"
|
| 765 |
],
|
| 766 |
value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
|
| 767 |
label = 'Embedding model',
|
|
@@ -779,7 +794,7 @@ with gr.Blocks() as demo:
|
|
| 779 |
with gr.Row():
|
| 780 |
with gr.Column():
|
| 781 |
load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
|
| 782 |
-
verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=
|
| 783 |
with gr.Column():
|
| 784 |
temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
|
| 785 |
with gr.Column():
|
|
@@ -856,7 +871,7 @@ with gr.Blocks() as demo:
|
|
| 856 |
interactive=True,
|
| 857 |
)
|
| 858 |
with gr.Column(scale=5):
|
| 859 |
-
qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=
|
| 860 |
query_send_btn = gr.Button(value="▶")
|
| 861 |
|
| 862 |
# gr.Examples(["機械学習について説明してください"], inputs=[query])
|
|
@@ -865,5 +880,4 @@ with gr.Blocks() as demo:
|
|
| 865 |
|
| 866 |
if __name__ == "__main__":
|
| 867 |
demo.queue(concurrency_count=5)
|
| 868 |
-
demo.launch(debug=True, inbrowser=True)
|
| 869 |
-
|
|
|
|
| 104 |
|
| 105 |
self.cache_clear()
|
| 106 |
|
| 107 |
+
# # --------------------------------------
|
| 108 |
+
# # Load Chat History as a list
|
| 109 |
+
# # --------------------------------------
|
| 110 |
+
# def load_chat_history(self) -> list:
|
| 111 |
+
# chat_history = []
|
| 112 |
+
# try:
|
| 113 |
+
# chat_memory = self.memory.load_memory_variables({})['chat_history']
|
| 114 |
+
# except KeyError:
|
| 115 |
+
# return chat_history
|
| 116 |
+
|
| 117 |
+
# # チャット履歴をペアごとに読み取る
|
| 118 |
+
# for i in range(0, len(chat_memory), 2):
|
| 119 |
+
# user_message = chat_memory[i].content
|
| 120 |
+
# ai_message = ""
|
| 121 |
+
# if i + 1 < len(chat_memory):
|
| 122 |
+
# ai_message = chat_memory[i + 1].content
|
| 123 |
+
# chat_history.append([user_message, ai_message])
|
| 124 |
+
# return chat_history
|
| 125 |
|
| 126 |
# --------------------------------------
|
| 127 |
# 自作TextSplitter(テキストをLLMのトークン数内に分割)
|
|
|
|
| 627 |
return ss, chat_history
|
| 628 |
|
| 629 |
def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
|
| 630 |
+
|
| 631 |
+
if ss.llm is None:
|
| 632 |
+
response = "LLMが設定されていません。設定画面で任意のモデルを選択してください。"
|
| 633 |
+
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
| 634 |
+
return ss, ""
|
| 635 |
+
|
| 636 |
+
elif qa_flag is True and ss.embeddings is None:
|
| 637 |
+
response = "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。"
|
| 638 |
+
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
| 639 |
+
|
| 640 |
+
# QA Model
|
| 641 |
+
elif qa_flag is True and ss.embeddings is not None:
|
| 642 |
ss = qa_predict(ss, query) # LLMで回答を生成
|
| 643 |
|
| 644 |
+
# Chat Model
|
| 645 |
else:
|
| 646 |
ss = conversation_prep(ss)
|
| 647 |
ss = chat_predict(ss, query)
|
|
|
|
| 649 |
return ss, "" # ssとquery欄(空欄)
|
| 650 |
|
| 651 |
def chat_predict(ss: SessionState, query) -> SessionState:
|
| 652 |
+
response = ss.conversation_chain.predict(query=query)
|
| 653 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
| 654 |
return ss
|
| 655 |
|
|
|
|
| 706 |
|
| 707 |
# 回答を1文字ずつチャット画面に表示する
|
| 708 |
def show_response(ss: SessionState) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
chat_history = [list(item) for item in ss.dialogue] # タプルをリストに変換して、メモリから会話履歴を取得
|
| 710 |
response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
|
| 711 |
chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
|
| 712 |
|
| 713 |
+
if response is None:
|
| 714 |
+
response = "回答を生成できませんでした。"
|
| 715 |
+
|
| 716 |
for character in response:
|
| 717 |
chat_history[-1][1] += character
|
| 718 |
time.sleep(0.05)
|
|
|
|
| 727 |
# API KEY をセット/クリアする関数
|
| 728 |
# --------------------------------------
|
| 729 |
def openai_api_setfn(openai_api_key) -> str:
|
| 730 |
+
if openai_api_key == "kikagaku":
|
| 731 |
+
os.environ["OPENAI_API_KEY"] = os.getenv("kikagaku_demo")
|
| 732 |
+
status_message = "✅ キカガク専用DEMOへようこそ!APIキーを設定しました"
|
| 733 |
+
return status_message
|
| 734 |
+
elif not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
|
| 735 |
os.environ["OPENAI_API_KEY"] = ""
|
| 736 |
status_message = "❌ 有効なAPIキーを入力してください"
|
| 737 |
return status_message
|
|
|
|
| 776 |
'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
|
| 777 |
'oshizo/sbert-jsnli-luke-japanese-base-lite',
|
| 778 |
'text-embedding-ada-002',
|
| 779 |
+
# "None"
|
| 780 |
],
|
| 781 |
value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
|
| 782 |
label = 'Embedding model',
|
|
|
|
| 794 |
with gr.Row():
|
| 795 |
with gr.Column():
|
| 796 |
load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
|
| 797 |
+
verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True)
|
| 798 |
with gr.Column():
|
| 799 |
temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
|
| 800 |
with gr.Column():
|
|
|
|
| 871 |
interactive=True,
|
| 872 |
)
|
| 873 |
with gr.Column(scale=5):
|
| 874 |
+
qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=False)
|
| 875 |
query_send_btn = gr.Button(value="▶")
|
| 876 |
|
| 877 |
# gr.Examples(["機械学習について説明してください"], inputs=[query])
|
|
|
|
| 880 |
|
| 881 |
if __name__ == "__main__":
|
| 882 |
demo.queue(concurrency_count=5)
|
| 883 |
+
demo.launch(debug=True, inbrowser=True)
|
|
|