cawacci commited on
Commit
15aa1d4
·
1 Parent(s): 38fdae2

without chat mode

Browse files
Files changed (1) hide show
  1. app.py +44 -30
app.py CHANGED
@@ -104,24 +104,24 @@ class SessionState:
104
 
105
  self.cache_clear()
106
 
107
- # --------------------------------------
108
- # Load Chat History as a list
109
- # --------------------------------------
110
- def load_chat_history(self) -> list:
111
- chat_history = []
112
- try:
113
- chat_memory = self.memory.load_memory_variables({})['chat_history']
114
- except KeyError:
115
- return chat_history
116
-
117
- # チャット履歴をペアごとに読み取る
118
- for i in range(0, len(chat_memory), 2):
119
- user_message = chat_memory[i].content
120
- ai_message = ""
121
- if i + 1 < len(chat_memory):
122
- ai_message = chat_memory[i + 1].content
123
- chat_history.append([user_message, ai_message])
124
- return chat_history
125
 
126
  # --------------------------------------
127
  # 自作TextSplitter(テキストをLLMのトークン数内に分割)
@@ -627,9 +627,21 @@ def user(ss: SessionState, query) -> (SessionState, list):
627
  return ss, chat_history
628
 
629
  def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
630
- if qa_flag is True:
 
 
 
 
 
 
 
 
 
 
 
631
  ss = qa_predict(ss, query) # LLMで回答を生成
632
 
 
633
  else:
634
  ss = conversation_prep(ss)
635
  ss = chat_predict(ss, query)
@@ -637,7 +649,7 @@ def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
637
  return ss, "" # ssとquery欄(空欄)
638
 
639
  def chat_predict(ss: SessionState, query) -> SessionState:
640
- response = ss.conversation_chain.predict(input=query)
641
  ss.dialogue[-1] = (ss.dialogue[-1][0], response)
642
  return ss
643
 
@@ -694,14 +706,13 @@ def qa_predict(ss: SessionState, query) -> SessionState:
694
 
695
  # 回答を1文字ずつチャット画面に表示する
696
  def show_response(ss: SessionState) -> str:
697
- # chat_history = ss.load_chat_history() # メモリから会話履歴をリスト型で取得
698
- # response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
699
- # chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
700
-
701
  chat_history = [list(item) for item in ss.dialogue] # タプルをリストに変換して、メモリから会話履歴を取得
702
  response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
703
  chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
704
 
 
 
 
705
  for character in response:
706
  chat_history[-1][1] += character
707
  time.sleep(0.05)
@@ -716,7 +727,11 @@ with gr.Blocks() as demo:
716
  # API KEY をセット/クリアする関数
717
  # --------------------------------------
718
  def openai_api_setfn(openai_api_key) -> str:
719
- if not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
 
 
 
 
720
  os.environ["OPENAI_API_KEY"] = ""
721
  status_message = "❌ 有効なAPIキーを入力してください"
722
  return status_message
@@ -761,7 +776,7 @@ with gr.Blocks() as demo:
761
  'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
762
  'oshizo/sbert-jsnli-luke-japanese-base-lite',
763
  'text-embedding-ada-002',
764
- "None"
765
  ],
766
  value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
767
  label = 'Embedding model',
@@ -779,7 +794,7 @@ with gr.Blocks() as demo:
779
  with gr.Row():
780
  with gr.Column():
781
  load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
782
- verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=False)
783
  with gr.Column():
784
  temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
785
  with gr.Column():
@@ -856,7 +871,7 @@ with gr.Blocks() as demo:
856
  interactive=True,
857
  )
858
  with gr.Column(scale=5):
859
- qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=True)
860
  query_send_btn = gr.Button(value="▶")
861
 
862
  # gr.Examples(["機械学習について説明してください"], inputs=[query])
@@ -865,5 +880,4 @@ with gr.Blocks() as demo:
865
 
866
  if __name__ == "__main__":
867
  demo.queue(concurrency_count=5)
868
- demo.launch(debug=True, inbrowser=True)
869
-
 
104
 
105
  self.cache_clear()
106
 
107
+ # # --------------------------------------
108
+ # # Load Chat History as a list
109
+ # # --------------------------------------
110
+ # def load_chat_history(self) -> list:
111
+ # chat_history = []
112
+ # try:
113
+ # chat_memory = self.memory.load_memory_variables({})['chat_history']
114
+ # except KeyError:
115
+ # return chat_history
116
+
117
+ # # チャット履歴をペアごとに読み取る
118
+ # for i in range(0, len(chat_memory), 2):
119
+ # user_message = chat_memory[i].content
120
+ # ai_message = ""
121
+ # if i + 1 < len(chat_memory):
122
+ # ai_message = chat_memory[i + 1].content
123
+ # chat_history.append([user_message, ai_message])
124
+ # return chat_history
125
 
126
  # --------------------------------------
127
  # 自作TextSplitter(テキストをLLMのトークン数内に分割)
 
627
  return ss, chat_history
628
 
629
  def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
630
+
631
+ if ss.llm is None:
632
+ response = "LLMが設定されていません。設定画面で任意のモデルを選択してください。"
633
+ ss.dialogue[-1] = (ss.dialogue[-1][0], response)
634
+ return ss, ""
635
+
636
+ elif qa_flag is True and ss.embeddings is None:
637
+ response = "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。"
638
+ ss.dialogue[-1] = (ss.dialogue[-1][0], response)
639
+
640
+ # QA Model
641
+ elif qa_flag is True and ss.embeddings is not None:
642
  ss = qa_predict(ss, query) # LLMで回答を生成
643
 
644
+ # Chat Model
645
  else:
646
  ss = conversation_prep(ss)
647
  ss = chat_predict(ss, query)
 
649
  return ss, "" # ssとquery欄(空欄)
650
 
651
  def chat_predict(ss: SessionState, query) -> SessionState:
652
+ response = ss.conversation_chain.predict(query=query)
653
  ss.dialogue[-1] = (ss.dialogue[-1][0], response)
654
  return ss
655
 
 
706
 
707
  # 回答を1文字ずつチャット画面に表示する
708
  def show_response(ss: SessionState) -> str:
 
 
 
 
709
  chat_history = [list(item) for item in ss.dialogue] # タプルをリストに変換して、メモリから会話履歴を取得
710
  response = chat_history[-1][1] # メモリから最新の会話[-1]を取得し、チャットボットの回答[1]を退避
711
  chat_history[-1][1] = "" # 逐次表示のため、チャットボットの回答[1]を空にする
712
 
713
+ if response is None:
714
+ response = "回答を生成できませんでした。"
715
+
716
  for character in response:
717
  chat_history[-1][1] += character
718
  time.sleep(0.05)
 
727
  # API KEY をセット/クリアする関数
728
  # --------------------------------------
729
  def openai_api_setfn(openai_api_key) -> str:
730
+ if openai_api_key == "kikagaku":
731
+ os.environ["OPENAI_API_KEY"] = os.getenv("kikagaku_demo")
732
+ status_message = "✅ キカガク専用DEMOへようこそ!APIキーを設定しました"
733
+ return status_message
734
+ elif not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
735
  os.environ["OPENAI_API_KEY"] = ""
736
  status_message = "❌ 有効なAPIキーを入力してください"
737
  return status_message
 
776
  'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
777
  'oshizo/sbert-jsnli-luke-japanese-base-lite',
778
  'text-embedding-ada-002',
779
+ # "None"
780
  ],
781
  value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
782
  label = 'Embedding model',
 
794
  with gr.Row():
795
  with gr.Column():
796
  load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
797
+ verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True)
798
  with gr.Column():
799
  temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
800
  with gr.Column():
 
871
  interactive=True,
872
  )
873
  with gr.Column(scale=5):
874
+ qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=False)
875
  query_send_btn = gr.Button(value="▶")
876
 
877
  # gr.Examples(["機械学習について説明してください"], inputs=[query])
 
880
 
881
  if __name__ == "__main__":
882
  demo.queue(concurrency_count=5)
883
+ demo.launch(debug=True, inbrowser=True)