Tuchuanhuhuhu committed
Commit 4b9ef74 · 1 Parent(s): 67474f7

feat: add Azure OpenAI support

config_example.json CHANGED
@@ -8,6 +8,12 @@
     "minimax_api_key": "", // Your MiniMax API Key, for the MiniMax chat model
     "minimax_group_id": "", // Your MiniMax Group ID, for the MiniMax chat model
 
+    //== Azure ==
+    "azure_openai_api_key": "", // Your Azure OpenAI API Key, for the Azure OpenAI chat model
+    "azure_api_base_url": "", // Your Azure base URL
+    "azure_openai_api_version": "2023-05-15", // Your Azure OpenAI API version
+    "azure_deployment_name": "", // Your Azure deployment name
+
     //== Basic config ==
     "language": "auto", // UI language; one of "auto", "zh-CN", "en-US", "ja-JP", "ko-KR"
     "users": [], // user list, [[username1, password1], [username2, password2], ...]
modules/config.py CHANGED
@@ -39,6 +39,12 @@ if os.path.exists("config.json"):
 else:
     config = {}
 
+def load_config_to_environ(key_list):
+    global config
+    for key in key_list:
+        if key in config:
+            os.environ[key.upper()] = os.environ.get(key.upper(), config[key])
+
 sensitive_id = config.get("sensitive_id", "")
 sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
 
@@ -97,6 +103,8 @@ os.environ["MINIMAX_API_KEY"] = minimax_api_key
 minimax_group_id = config.get("minimax_group_id", "")
 os.environ["MINIMAX_GROUP_ID"] = minimax_group_id
 
+load_config_to_environ(["azure_openai_api_key", "azure_api_base_url", "azure_openai_api_version", "azure_deployment_name"])
+
 
 usage_limit = os.environ.get("USAGE_LIMIT", config.get("usage_limit", 120))
 
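load_config_to_environ copies config keys into uppercase environment variables, but an environment variable that is already set takes precedence over the config value. A minimal standalone sketch of that precedence rule (the key is a real config key from this commit; both values are made up):

import os

config = {"azure_openai_api_key": "key-from-config-json"}

def load_config_to_environ(key_list):
    # Uppercase each key; keep a pre-existing environment value if present.
    for key in key_list:
        if key in config:
            os.environ[key.upper()] = os.environ.get(key.upper(), config[key])

os.environ["AZURE_OPENAI_API_KEY"] = "key-from-env"
load_config_to_environ(["azure_openai_api_key"])
print(os.environ["AZURE_OPENAI_API_KEY"])  # "key-from-env" -- the environment wins
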
modules/models/azure.py ADDED
@@ -0,0 +1,17 @@
+from langchain.chat_models import AzureChatOpenAI
+import os
+
+from .base_model import Base_Chat_Langchain_Client
+
+# load_config_to_environ(["azure_openai_api_key", "azure_api_base_url", "azure_openai_api_version", "azure_deployment_name"])
+
+class Azure_OpenAI_Client(Base_Chat_Langchain_Client):
+    def setup_model(self):
+        # set up the model, then return it
+        return AzureChatOpenAI(
+            openai_api_base=os.environ["AZURE_API_BASE_URL"],
+            openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
+            deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
+            openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
+            openai_api_type="azure",
+        )
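
A hypothetical usage sketch, assuming the four AZURE_* environment variables are set (normally done by load_config_to_environ in modules/config.py) and the code runs inside the package; all values shown are placeholders:

import os

os.environ.setdefault("AZURE_OPENAI_API_KEY", "<your-key>")
os.environ.setdefault("AZURE_API_BASE_URL", "https://your-resource.openai.azure.com")
os.environ.setdefault("AZURE_OPENAI_API_VERSION", "2023-05-15")
os.environ.setdefault("AZURE_DEPLOYMENT_NAME", "your-gpt-35-deployment")

from modules.models.azure import Azure_OpenAI_Client

# The constructor calls setup_model(), which builds the AzureChatOpenAI instance;
# answers then stream through the inherited get_answer_stream_iter().
client = Azure_OpenAI_Client("Azure OpenAI", user_name="demo")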
modules/models/base_model.py CHANGED
@@ -29,6 +29,8 @@ from langchain.input import print_text
 from langchain.schema import AgentAction, AgentFinish, LLMResult
 from threading import Thread, Condition
 from collections import deque
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
 
 from ..presets import *
 from ..index_func import *
@@ -36,6 +38,7 @@ from ..utils import *
 from .. import shared
 from ..config import retrieve_proxy
 
+
 class CallbackToIterator:
     def __init__(self):
         self.queue = deque()
@@ -52,7 +55,8 @@ class CallbackToIterator:
 
     def __next__(self):
         with self.cond:
-            while not self.queue and not self.finished:  # Wait for a value to be added to the queue.
+            # Wait for a value to be added to the queue.
+            while not self.queue and not self.finished:
                 self.cond.wait()
             if not self.queue:
                 raise StopIteration()
@@ -63,6 +67,7 @@
             self.finished = True
             self.cond.notify()  # Wake up the generator if it's waiting.
 
+
 def get_action_description(text):
     match = re.search('```(.*?)```', text, re.S)
     json_text = match.group(1)
@@ -76,6 +81,7 @@ def get_action_description(text):
     else:
         return ""
 
+
 class ChuanhuCallbackHandler(BaseCallbackHandler):
 
     def __init__(self, callback) -> None:
@@ -117,6 +123,10 @@ class ChuanhuCallbackHandler(BaseCallbackHandler):
         """Run on new LLM token. Only available when streaming is enabled."""
         self.callback(token)
 
+    def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
+        """Run when a chat model starts running."""
+        pass
+
 
 class ModelType(Enum):
     Unknown = -1
@@ -130,6 +140,7 @@ class ModelType(Enum):
     Minimax = 7
     ChuanhuAgent = 8
     GooglePaLM = 9
+    LangchainChat = 10
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -155,6 +166,8 @@ class ModelType(Enum):
             model_type = ModelType.ChuanhuAgent
         elif "palm" in model_name_lower:
             model_type = ModelType.GooglePaLM
+        elif "azure" in model_name_lower or "api" in model_name_lower:
+            model_type = ModelType.LangchainChat
         else:
             model_type = ModelType.Unknown
         return model_type
@@ -164,7 +177,7 @@ class BaseLLMModel:
     def __init__(
         self,
         model_name,
-        system_prompt="",
+        system_prompt=INITIAL_SYSTEM_PROMPT,
         temperature=1.0,
        top_p=1.0,
         n_choices=1,
@@ -204,7 +217,8 @@ class BaseLLMModel:
        conversations are stored in self.history, with the most recent question, in OpenAI format
        should return a generator, each time give the next word (str) in the answer
        """
-        logging.warning("stream predict not implemented, using at once predict instead")
+        logging.warning(
+            "stream predict not implemented, using at once predict instead")
         response, _ = self.get_answer_at_once()
         yield response
 
@@ -215,7 +229,8 @@
        the answer (str)
        total token count (int)
        """
-        logging.warning("at once predict not implemented, using stream predict instead")
+        logging.warning(
+            "at once predict not implemented, using stream predict instead")
         response_iter = self.get_answer_stream_iter()
         count = 0
         for response in response_iter:
@@ -276,9 +291,11 @@ class BaseLLMModel:
             self.history[-2] = construct_user(fake_input)
         chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
         if fake_input is not None:
-            self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
+            self.all_token_counts[-1] += count_token(
+                construct_assistant(ai_reply))
         else:
-            self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
+            self.all_token_counts[-1] = total_token_count - \
+                sum(self.all_token_counts)
         status_text = self.token_message()
         return chatbot, status_text
@@ -302,10 +319,13 @@ class BaseLLMModel:
         from langchain.chat_models import ChatOpenAI
         from langchain.callbacks import StdOutCallbackHandler
         prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
-        PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
+        PROMPT = PromptTemplate(
+            template=prompt_template, input_variables=["text"])
         llm = ChatOpenAI()
-        chain = load_summarize_chain(llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
-        summary = chain({"input_documents": list(index.docstore.__dict__["_dict"].values())}, return_only_outputs=True)["output_text"]
+        chain = load_summarize_chain(
+            llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
+        summary = chain({"input_documents": list(index.docstore.__dict__[
+                        "_dict"].values())}, return_only_outputs=True)["output_text"]
         print(i18n("总结") + f": {summary}")
         chatbot.append([i18n("上传了")+str(len(files))+"个文件", summary])
         return chatbot, status
@@ -326,9 +346,12 @@ class BaseLLMModel:
             msg = "索引获取成功,生成回答中……"
             logging.info(msg)
             with retrieve_proxy():
-                retriever = VectorStoreRetriever(vectorstore=index, search_type="similarity_score_threshold",search_kwargs={"k":6, "score_threshold": 0.5})
-                relevant_documents = retriever.get_relevant_documents(real_inputs)
-                reference_results = [[d.page_content.strip("�"), os.path.basename(d.metadata["source"])] for d in relevant_documents]
+                retriever = VectorStoreRetriever(vectorstore=index, search_type="similarity_score_threshold", search_kwargs={
+                                                 "k": 6, "score_threshold": 0.5})
+                relevant_documents = retriever.get_relevant_documents(
+                    real_inputs)
+                reference_results = [[d.page_content.strip("�"), os.path.basename(
+                    d.metadata["source"])] for d in relevant_documents]
             reference_results = add_source_numbers(reference_results)
             display_append = add_details(reference_results)
             display_append = "\n\n" + "".join(display_append)
@@ -355,7 +378,8 @@ class BaseLLMModel:
             )
             reference_results = add_source_numbers(reference_results)
             # display_append = "<ol>\n\n" + "".join(display_append) + "</ol>"
-            display_append = '<div class = "source-a">' + "".join(display_append) + '</div>'
+            display_append = '<div class = "source-a">' + \
+                "".join(display_append) + '</div>'
             real_inputs = (
                 replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
                 .replace("{query}", real_inputs)
@@ -379,14 +403,16 @@ class BaseLLMModel:
 
         status_text = "开始生成回答……"
         logging.info(
-            "用户" + f"{self.user_identifier}" + "的输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
+            "用户" + f"{self.user_identifier}" + "的输入为:" +
+            colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
         )
         if should_check_token_count:
             yield chatbot + [(inputs, "")], status_text
         if reply_language == "跟随问题语言(不稳定)":
             reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
 
-        limited_context, fake_inputs, display_append, inputs, chatbot = self.prepare_inputs(real_inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language, chatbot=chatbot)
+        limited_context, fake_inputs, display_append, inputs, chatbot = self.prepare_inputs(
+            real_inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language, chatbot=chatbot)
         yield chatbot + [(fake_inputs, "")], status_text
 
         if (
@@ -587,7 +613,8 @@ class BaseLLMModel:
         self.history = []
         self.all_token_counts = []
         self.interrupted = False
-        pathlib.Path(os.path.join(HISTORY_DIR, self.user_identifier, new_auto_history_filename(os.path.join(HISTORY_DIR, self.user_identifier)))).touch()
+        pathlib.Path(os.path.join(HISTORY_DIR, self.user_identifier, new_auto_history_filename(
+            os.path.join(HISTORY_DIR, self.user_identifier)))).touch()
         return [], self.token_message([0])
 
     def delete_first_conversation(self):
@@ -630,7 +657,8 @@ class BaseLLMModel:
 
     def auto_save(self, chatbot):
         history_file_path = get_history_filepath(self.user_identifier)
-        save_file(history_file_path, self.system_prompt, self.history, chatbot, self.user_identifier)
+        save_file(history_file_path, self.system_prompt,
+                  self.history, chatbot, self.user_identifier)
 
     def export_markdown(self, filename, chatbot, user_name):
         if filename == "":
@@ -646,7 +674,8 @@ class BaseLLMModel:
             filename = filename.name
         try:
             if "/" not in filename:
-                history_file_path = os.path.join(HISTORY_DIR, user_name, filename)
+                history_file_path = os.path.join(
+                    HISTORY_DIR, user_name, filename)
             else:
                 history_file_path = filename
             with open(history_file_path, "r", encoding="utf-8") as f:
@@ -695,10 +724,10 @@ class BaseLLMModel:
             self.reset()
             return self.system_prompt, gr.update()
         history_file_path = get_history_filepath(self.user_identifier)
-        filename, system_prompt, chatbot = self.load_chat_history(history_file_path, self.user_identifier)
+        filename, system_prompt, chatbot = self.load_chat_history(
+            history_file_path, self.user_identifier)
         return system_prompt, chatbot
 
-
     def like(self):
         """like the last response, implement if needed
         """
@@ -708,3 +737,47 @@ class BaseLLMModel:
         """dislike the last response, implement if needed
         """
         return gr.update()
+
+
+class Base_Chat_Langchain_Client(BaseLLMModel):
+    def __init__(self, model_name, user_name=""):
+        super().__init__(model_name, user=user_name)
+        self.need_api_key = False
+        self.model = self.setup_model()
+
+    def setup_model(self):
+        # implement this to set up the model, then return it
+        pass
+
+    def _get_langchain_style_history(self):
+        history = [SystemMessage(content=self.system_prompt)]
+        for i in self.history:
+            if i["role"] == "user":
+                history.append(HumanMessage(content=i["content"]))
+            elif i["role"] == "assistant":
+                history.append(AIMessage(content=i["content"]))
+        return history
+
+    def get_answer_at_once(self):
+        assert isinstance(
+            self.model, BaseChatModel), "model is not instance of LangChain BaseChatModel"
+        history = self._get_langchain_style_history()
+        response = self.model(history)  # calling a BaseChatModel returns a single AIMessage
+        return response.content, count_token(construct_assistant(response.content))
+
+    def get_answer_stream_iter(self):
+        it = CallbackToIterator()
+        assert isinstance(
+            self.model, BaseChatModel), "model is not instance of LangChain BaseChatModel"
+        history = self._get_langchain_style_history()
+
+        def thread_func():
+            self.model(messages=history, callbacks=[
+                       ChuanhuCallbackHandler(it.callback)])
+            it.finish()
+        t = Thread(target=thread_func)
+        t.start()
+        partial_text = ""
+        for value in it:
+            partial_text += value
+            yield partial_text
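
Base_Chat_Langchain_Client.get_answer_stream_iter bridges LangChain's push-style token callbacks to the pull-style generator the UI expects. A standalone sketch of that producer/consumer pattern, reconstructed from the CallbackToIterator fragments visible in this diff (methods not shown in the diff are assumed):

from collections import deque
from threading import Thread, Condition

class CallbackToIterator:
    def __init__(self):
        self.queue = deque()
        self.cond = Condition()
        self.finished = False

    def callback(self, value):
        with self.cond:
            self.queue.append(value)
            self.cond.notify()  # wake the consumer

    def __iter__(self):
        return self

    def __next__(self):
        with self.cond:
            # Wait for a value to be added to the queue.
            while not self.queue and not self.finished:
                self.cond.wait()
            if not self.queue:
                raise StopIteration()
            return self.queue.popleft()

    def finish(self):
        with self.cond:
            self.finished = True
            self.cond.notify()

it = CallbackToIterator()

def produce():
    for token in ["Hello", ", ", "world"]:
        it.callback(token)  # what ChuanhuCallbackHandler does per LLM token
    it.finish()

Thread(target=produce).start()
print("".join(it))  # -> "Hello, world"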
modules/models/models.py CHANGED
@@ -616,6 +616,9 @@ def get_model(
         from .Google_PaLM import Google_PaLM_Client
         access_key = os.environ.get("GOOGLE_PALM_API_KEY")
         model = Google_PaLM_Client(model_name, access_key, user_name=user_name)
+    elif model_type == ModelType.LangchainChat:
+        from .azure import Azure_OpenAI_Client
+        model = Azure_OpenAI_Client(model_name, user_name=user_name)
     elif model_type == ModelType.Unknown:
         raise ValueError(f"未知模型: {model_name}")
     logging.info(msg)
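
A quick dispatch sanity check, assuming the package is importable and using the corrected substring routing in ModelType.get_type:

from modules.models.base_model import ModelType

assert ModelType.get_type("Azure OpenAI") is ModelType.LangchainChat  # "azure" matches
assert ModelType.get_type("GooglePaLM") is ModelType.GooglePaLM       # "palm" is checked first

Any other model name containing "api" would also route to LangchainChat, and therefore to Azure_OpenAI_Client.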
modules/presets.py CHANGED
@@ -62,6 +62,7 @@ ONLINE_MODELS = [
     "川虎助理 Pro",
     "GooglePaLM",
     "xmchat",
+    "Azure OpenAI",
     "yuanai-1.0-base_10B",
     "yuanai-1.0-translate",
     "yuanai-1.0-dialog",