Spaces:
Running
Running
without chat mode
Browse files
app.py
CHANGED
@@ -104,24 +104,24 @@ class SessionState:
|
|
104 |
|
105 |
self.cache_clear()
|
106 |
|
107 |
-
# --------------------------------------
|
108 |
-
# Load Chat History as a list
|
109 |
-
# --------------------------------------
|
110 |
-
def load_chat_history(self) -> list:
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
# --------------------------------------
|
127 |
# 自作TextSplitter(テキストをLLMのトークン数内に分割)
|
@@ -627,9 +627,21 @@ def user(ss: SessionState, query) -> (SessionState, list):
|
|
627 |
return ss, chat_history
|
628 |
|
629 |
def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
|
630 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
631 |
ss = qa_predict(ss, query) # LLMで回答を生成
|
632 |
|
|
|
633 |
else:
|
634 |
ss = conversation_prep(ss)
|
635 |
ss = chat_predict(ss, query)
|
@@ -637,7 +649,7 @@ def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
|
|
637 |
return ss, "" # ssとquery欄(空欄)
|
638 |
|
639 |
def chat_predict(ss: SessionState, query) -> SessionState:
|
640 |
-
response = ss.conversation_chain.predict(
|
641 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
642 |
return ss
|
643 |
|
@@ -694,14 +706,13 @@ def qa_predict(ss: SessionState, query) -> SessionState:
|
|
694 |
|
695 |
# 回答を1文字ずつチャット画面に表示する
|
696 |
def show_response(ss: SessionState) -> str:
|
697 |
-
# chat_history = ss.load_chat_history() # メモリから会話履歴をリスト型で取得
|
698 |
-
# response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
|
699 |
-
# chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
|
700 |
-
|
701 |
chat_history = [list(item) for item in ss.dialogue] # タプルをリストに変換して、メモリから会話履歴を取得
|
702 |
response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
|
703 |
chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
|
704 |
|
|
|
|
|
|
|
705 |
for character in response:
|
706 |
chat_history[-1][1] += character
|
707 |
time.sleep(0.05)
|
@@ -716,7 +727,11 @@ with gr.Blocks() as demo:
|
|
716 |
# API KEY をセット/クリアする関数
|
717 |
# --------------------------------------
|
718 |
def openai_api_setfn(openai_api_key) -> str:
|
719 |
-
if
|
|
|
|
|
|
|
|
|
720 |
os.environ["OPENAI_API_KEY"] = ""
|
721 |
status_message = "❌ 有効なAPIキーを入力してください"
|
722 |
return status_message
|
@@ -761,7 +776,7 @@ with gr.Blocks() as demo:
|
|
761 |
'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
|
762 |
'oshizo/sbert-jsnli-luke-japanese-base-lite',
|
763 |
'text-embedding-ada-002',
|
764 |
-
"None"
|
765 |
],
|
766 |
value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
|
767 |
label = 'Embedding model',
|
@@ -779,7 +794,7 @@ with gr.Blocks() as demo:
|
|
779 |
with gr.Row():
|
780 |
with gr.Column():
|
781 |
load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
|
782 |
-
verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=
|
783 |
with gr.Column():
|
784 |
temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
|
785 |
with gr.Column():
|
@@ -856,7 +871,7 @@ with gr.Blocks() as demo:
|
|
856 |
interactive=True,
|
857 |
)
|
858 |
with gr.Column(scale=5):
|
859 |
-
qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=
|
860 |
query_send_btn = gr.Button(value="▶")
|
861 |
|
862 |
# gr.Examples(["機械学習について説明してください"], inputs=[query])
|
@@ -865,5 +880,4 @@ with gr.Blocks() as demo:
|
|
865 |
|
866 |
if __name__ == "__main__":
|
867 |
demo.queue(concurrency_count=5)
|
868 |
-
demo.launch(debug=True, inbrowser=True)
|
869 |
-
|
|
|
104 |
|
105 |
self.cache_clear()
|
106 |
|
107 |
+
# # --------------------------------------
|
108 |
+
# # Load Chat History as a list
|
109 |
+
# # --------------------------------------
|
110 |
+
# def load_chat_history(self) -> list:
|
111 |
+
# chat_history = []
|
112 |
+
# try:
|
113 |
+
# chat_memory = self.memory.load_memory_variables({})['chat_history']
|
114 |
+
# except KeyError:
|
115 |
+
# return chat_history
|
116 |
+
|
117 |
+
# # チャット履歴をペアごとに読み取る
|
118 |
+
# for i in range(0, len(chat_memory), 2):
|
119 |
+
# user_message = chat_memory[i].content
|
120 |
+
# ai_message = ""
|
121 |
+
# if i + 1 < len(chat_memory):
|
122 |
+
# ai_message = chat_memory[i + 1].content
|
123 |
+
# chat_history.append([user_message, ai_message])
|
124 |
+
# return chat_history
|
125 |
|
126 |
# --------------------------------------
|
127 |
# 自作TextSplitter(テキストをLLMのトークン数内に分割)
|
|
|
627 |
return ss, chat_history
|
628 |
|
629 |
def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
|
630 |
+
|
631 |
+
if ss.llm is None:
|
632 |
+
response = "LLMが設定されていません。設定画面で任意のモデルを選択してください。"
|
633 |
+
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
634 |
+
return ss, ""
|
635 |
+
|
636 |
+
elif qa_flag is True and ss.embeddings is None:
|
637 |
+
response = "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。"
|
638 |
+
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
639 |
+
|
640 |
+
# QA Model
|
641 |
+
elif qa_flag is True and ss.embeddings is not None:
|
642 |
ss = qa_predict(ss, query) # LLMで回答を生成
|
643 |
|
644 |
+
# Chat Model
|
645 |
else:
|
646 |
ss = conversation_prep(ss)
|
647 |
ss = chat_predict(ss, query)
|
|
|
649 |
return ss, "" # ssとquery欄(空欄)
|
650 |
|
651 |
def chat_predict(ss: SessionState, query) -> SessionState:
|
652 |
+
response = ss.conversation_chain.predict(query=query)
|
653 |
ss.dialogue[-1] = (ss.dialogue[-1][0], response)
|
654 |
return ss
|
655 |
|
|
|
706 |
|
707 |
# 回答を1文字ずつチャット画面に表示する
|
708 |
def show_response(ss: SessionState) -> str:
|
|
|
|
|
|
|
|
|
709 |
chat_history = [list(item) for item in ss.dialogue] # タプルをリストに変換して、メモリから会話履歴を取得
|
710 |
response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
|
711 |
chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
|
712 |
|
713 |
+
if response is None:
|
714 |
+
response = "回答を生成できませんでした。"
|
715 |
+
|
716 |
for character in response:
|
717 |
chat_history[-1][1] += character
|
718 |
time.sleep(0.05)
|
|
|
727 |
# API KEY をセット/クリアする関数
|
728 |
# --------------------------------------
|
729 |
def openai_api_setfn(openai_api_key) -> str:
|
730 |
+
if openai_api_key == "kikagaku":
|
731 |
+
os.environ["OPENAI_API_KEY"] = os.getenv("kikagaku_demo")
|
732 |
+
status_message = "✅ キカガク専用DEMOへようこそ!APIキーを設定しました"
|
733 |
+
return status_message
|
734 |
+
elif not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
|
735 |
os.environ["OPENAI_API_KEY"] = ""
|
736 |
status_message = "❌ 有効なAPIキーを入力してください"
|
737 |
return status_message
|
|
|
776 |
'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
|
777 |
'oshizo/sbert-jsnli-luke-japanese-base-lite',
|
778 |
'text-embedding-ada-002',
|
779 |
+
# "None"
|
780 |
],
|
781 |
value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
|
782 |
label = 'Embedding model',
|
|
|
794 |
with gr.Row():
|
795 |
with gr.Column():
|
796 |
load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
|
797 |
+
verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True)
|
798 |
with gr.Column():
|
799 |
temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
|
800 |
with gr.Column():
|
|
|
871 |
interactive=True,
|
872 |
)
|
873 |
with gr.Column(scale=5):
|
874 |
+
qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=False)
|
875 |
query_send_btn = gr.Button(value="▶")
|
876 |
|
877 |
# gr.Examples(["機械学習について説明してください"], inputs=[query])
|
|
|
880 |
|
881 |
if __name__ == "__main__":
|
882 |
demo.queue(concurrency_count=5)
|
883 |
+
demo.launch(debug=True, inbrowser=True)
|
|